diff --git a/.devcontainer/scripts/onCreateCommand.sh b/.devcontainer/scripts/onCreateCommand.sh index e4aaf8ace..eba12a357 100755 --- a/.devcontainer/scripts/onCreateCommand.sh +++ b/.devcontainer/scripts/onCreateCommand.sh @@ -1,5 +1,21 @@ #!/bin/bash +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + set -x az extension add --name spring diff --git a/.editorconfig b/.editorconfig index 3e1127a6d..8cbc8ccb4 100644 --- a/.editorconfig +++ b/.editorconfig @@ -8,3 +8,5 @@ indent_style = tab indent_size = 4 continuation_indent_size = 8 end_of_line = lf + +insert_final_newline = true diff --git a/.mvn/extensions.xml b/.mvn/extensions.xml index c7e6507ac..31675c589 100644 --- a/.mvn/extensions.xml +++ b/.mvn/extensions.xml @@ -1,4 +1,20 @@ + + fr.jcgay.maven diff --git a/.mvn/wrapper/maven-wrapper.properties b/.mvn/wrapper/maven-wrapper.properties index dc3affce3..da1385eeb 100644 --- a/.mvn/wrapper/maven-wrapper.properties +++ b/.mvn/wrapper/maven-wrapper.properties @@ -1,18 +1,17 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at # -# https://www.apache.org/licenses/LICENSE-2.0 +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.8.6/apache-maven-3.8.6-bin.zip wrapperUrl=https://repo.maven.apache.org/maven2/org/apache/maven/wrapper/maven-wrapper/3.1.1/maven-wrapper-3.1.1.jar diff --git a/document-readers/markdown-reader/pom.xml b/document-readers/markdown-reader/pom.xml index 5922ea2b4..9ad6aa6a1 100644 --- a/document-readers/markdown-reader/pom.xml +++ b/document-readers/markdown-reader/pom.xml @@ -1,4 +1,20 @@ + + diff --git a/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/MarkdownDocumentReader.java b/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/MarkdownDocumentReader.java index 7ed8aa6b5..19ebed9ca 100644 --- a/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/MarkdownDocumentReader.java +++ b/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/MarkdownDocumentReader.java @@ -1,18 +1,45 @@ -package org.springframework.ai.reader.markdown; +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import org.commonmark.node.*; -import org.commonmark.parser.Parser; -import org.springframework.ai.document.Document; -import org.springframework.ai.document.DocumentReader; -import org.springframework.ai.reader.markdown.config.MarkdownDocumentReaderConfig; -import org.springframework.core.io.DefaultResourceLoader; -import org.springframework.core.io.Resource; +package org.springframework.ai.reader.markdown; import java.io.IOException; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.List; +import org.commonmark.node.AbstractVisitor; +import org.commonmark.node.BlockQuote; +import org.commonmark.node.Code; +import org.commonmark.node.FencedCodeBlock; +import org.commonmark.node.HardLineBreak; +import org.commonmark.node.Heading; +import org.commonmark.node.ListItem; +import org.commonmark.node.Node; +import org.commonmark.node.SoftLineBreak; +import org.commonmark.node.Text; +import org.commonmark.node.ThematicBreak; +import org.commonmark.parser.Parser; + +import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentReader; +import org.springframework.ai.reader.markdown.config.MarkdownDocumentReaderConfig; +import org.springframework.core.io.DefaultResourceLoader; +import org.springframework.core.io.Resource; + /** * Reads the given Markdown resource and groups headers, paragraphs, or text divided by * horizontal lines (depending on the @@ -58,10 +85,10 @@ public class MarkdownDocumentReader implements DocumentReader { */ @Override public List get() { - try (var input = markdownResource.getInputStream()) { - Node node = parser.parseReader(new InputStreamReader(input)); + try (var input = this.markdownResource.getInputStream()) { + Node node = this.parser.parseReader(new InputStreamReader(input)); - DocumentVisitor documentVisitor = new DocumentVisitor(config); + DocumentVisitor documentVisitor = new DocumentVisitor(this.config); node.accept(documentVisitor); return documentVisitor.getDocuments(); @@ -90,7 +117,7 @@ public class MarkdownDocumentReader implements DocumentReader { @Override public void visit(org.commonmark.node.Document document) { - currentDocumentBuilder = Document.builder(); + this.currentDocumentBuilder = Document.builder(); super.visit(document); } @@ -102,7 +129,7 @@ public class MarkdownDocumentReader implements DocumentReader { @Override public void visit(ThematicBreak thematicBreak) { - if (config.horizontalRuleCreateDocument) { + if (this.config.horizontalRuleCreateDocument) { buildAndFlush(); } super.visit(thematicBreak); @@ -128,32 +155,32 @@ public class MarkdownDocumentReader implements DocumentReader { @Override public void visit(BlockQuote blockQuote) { - if (!config.includeBlockquote) { + if (!this.config.includeBlockquote) { buildAndFlush(); } translateLineBreakToSpace(); - currentDocumentBuilder.withMetadata("category", "blockquote"); + this.currentDocumentBuilder.withMetadata("category", "blockquote"); super.visit(blockQuote); } @Override public void visit(Code code) { - currentParagraphs.add(code.getLiteral()); - currentDocumentBuilder.withMetadata("category", "code_inline"); + this.currentParagraphs.add(code.getLiteral()); + this.currentDocumentBuilder.withMetadata("category", "code_inline"); super.visit(code); } @Override public void visit(FencedCodeBlock fencedCodeBlock) { - if (!config.includeCodeBlock) { + if (!this.config.includeCodeBlock) { buildAndFlush(); } translateLineBreakToSpace(); - currentParagraphs.add(fencedCodeBlock.getLiteral()); - currentDocumentBuilder.withMetadata("category", "code_block"); - currentDocumentBuilder.withMetadata("lang", fencedCodeBlock.getInfo()); + this.currentParagraphs.add(fencedCodeBlock.getLiteral()); + this.currentDocumentBuilder.withMetadata("category", "code_block"); + this.currentDocumentBuilder.withMetadata("lang", fencedCodeBlock.getInfo()); buildAndFlush(); @@ -163,11 +190,11 @@ public class MarkdownDocumentReader implements DocumentReader { @Override public void visit(Text text) { if (text.getParent() instanceof Heading heading) { - currentDocumentBuilder.withMetadata("category", "header_%d".formatted(heading.getLevel())) + this.currentDocumentBuilder.withMetadata("category", "header_%d".formatted(heading.getLevel())) .withMetadata("title", text.getLiteral()); } else { - currentParagraphs.add(text.getLiteral()); + this.currentParagraphs.add(text.getLiteral()); } super.visit(text); @@ -176,29 +203,29 @@ public class MarkdownDocumentReader implements DocumentReader { public List getDocuments() { buildAndFlush(); - return documents; + return this.documents; } private void buildAndFlush() { - if (!currentParagraphs.isEmpty()) { - String content = String.join("", currentParagraphs); + if (!this.currentParagraphs.isEmpty()) { + String content = String.join("", this.currentParagraphs); - Document.Builder builder = currentDocumentBuilder.withContent(content); + Document.Builder builder = this.currentDocumentBuilder.withContent(content); - config.additionalMetadata.forEach(builder::withMetadata); + this.config.additionalMetadata.forEach(builder::withMetadata); Document document = builder.build(); - documents.add(document); + this.documents.add(document); - currentParagraphs.clear(); + this.currentParagraphs.clear(); } - currentDocumentBuilder = Document.builder(); + this.currentDocumentBuilder = Document.builder(); } private void translateLineBreakToSpace() { - if (!currentParagraphs.isEmpty()) { - currentParagraphs.add(" "); + if (!this.currentParagraphs.isEmpty()) { + this.currentParagraphs.add(" "); } } diff --git a/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/config/MarkdownDocumentReaderConfig.java b/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/config/MarkdownDocumentReaderConfig.java index d5ad3ec58..c22c573f0 100644 --- a/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/config/MarkdownDocumentReaderConfig.java +++ b/document-readers/markdown-reader/src/main/java/org/springframework/ai/reader/markdown/config/MarkdownDocumentReaderConfig.java @@ -1,12 +1,28 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.reader.markdown.config; +import java.util.HashMap; +import java.util.Map; + import org.springframework.ai.document.Document; import org.springframework.ai.reader.markdown.MarkdownDocumentReader; import org.springframework.util.Assert; -import java.util.HashMap; -import java.util.Map; - /** * Common configuration for the {@link MarkdownDocumentReader}. * @@ -23,10 +39,10 @@ public class MarkdownDocumentReaderConfig { public final Map additionalMetadata; public MarkdownDocumentReaderConfig(Builder builder) { - horizontalRuleCreateDocument = builder.horizontalRuleCreateDocument; - includeCodeBlock = builder.includeCodeBlock; - includeBlockquote = builder.includeBlockquote; - additionalMetadata = builder.additionalMetadata; + this.horizontalRuleCreateDocument = builder.horizontalRuleCreateDocument; + this.includeCodeBlock = builder.includeCodeBlock; + this.includeBlockquote = builder.includeBlockquote; + this.additionalMetadata = builder.additionalMetadata; } /** diff --git a/document-readers/markdown-reader/src/test/java/org/springframework/ai/reader/markdown/MarkdownDocumentReaderTest.java b/document-readers/markdown-reader/src/test/java/org/springframework/ai/reader/markdown/MarkdownDocumentReaderTest.java index 739dbbd70..69d3babe5 100644 --- a/document-readers/markdown-reader/src/test/java/org/springframework/ai/reader/markdown/MarkdownDocumentReaderTest.java +++ b/document-readers/markdown-reader/src/test/java/org/springframework/ai/reader/markdown/MarkdownDocumentReaderTest.java @@ -1,12 +1,29 @@ -package org.springframework.ai.reader.markdown; +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import org.junit.jupiter.api.Test; -import org.springframework.ai.document.Document; -import org.springframework.ai.reader.markdown.config.MarkdownDocumentReaderConfig; +package org.springframework.ai.reader.markdown; import java.util.List; import java.util.Map; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.document.Document; +import org.springframework.ai.reader.markdown.config.MarkdownDocumentReaderConfig; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.groups.Tuple.tuple; diff --git a/document-readers/pdf-reader/pom.xml b/document-readers/pdf-reader/pom.xml index c870c9176..eace8bd6d 100644 --- a/document-readers/pdf-reader/pom.xml +++ b/document-readers/pdf-reader/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/PagePdfDocumentReader.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/PagePdfDocumentReader.java index d1e95cd50..11fb99330 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/PagePdfDocumentReader.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/PagePdfDocumentReader.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf; -import java.awt.Rectangle; +import java.awt.*; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -24,9 +25,9 @@ import java.util.stream.Collectors; import org.apache.pdfbox.pdfparser.PDFParser; import org.apache.pdfbox.pdmodel.PDDocument; import org.apache.pdfbox.pdmodel.PDPage; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentReader; import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig; @@ -46,22 +47,22 @@ import org.springframework.util.StringUtils; */ public class PagePdfDocumentReader implements DocumentReader { - private final Logger logger = LoggerFactory.getLogger(getClass()); - - private static final String PDF_PAGE_REGION = "pdfPageRegion"; - public static final String METADATA_START_PAGE_NUMBER = "page_number"; public static final String METADATA_END_PAGE_NUMBER = "end_page_number"; public static final String METADATA_FILE_NAME = "file_name"; + private static final String PDF_PAGE_REGION = "pdfPageRegion"; + protected final PDDocument document; - private PdfDocumentReaderConfig config; + private final Logger logger = LoggerFactory.getLogger(getClass()); protected String resourceFileName; + private PdfDocumentReaderConfig config; + public PagePdfDocumentReader(String resourceUrl) { this(new DefaultResourceLoader().getResource(resourceUrl)); } @@ -103,15 +104,15 @@ public class PagePdfDocumentReader implements DocumentReader { int totalPages = this.document.getDocumentCatalog().getPages().getCount(); int logFrequency = totalPages > 10 ? totalPages / 10 : 1; // if less than 10 - // pages, print - // each iteration + // pages, print + // each iteration int counter = 0; PDPage lastPage = this.document.getDocumentCatalog().getPages().iterator().next(); for (PDPage page : this.document.getDocumentCatalog().getPages()) { lastPage = page; if (counter % logFrequency == 0 && counter / logFrequency < 10) { - logger.info("Processing PDF page: {}", (counter + 1)); + this.logger.info("Processing PDF page: {}", (counter + 1)); } counter++; @@ -153,7 +154,7 @@ public class PagePdfDocumentReader implements DocumentReader { readDocuments.add(toDocument(lastPage, pageTextGroupList.stream().collect(Collectors.joining()), startPageNumber, pageNumber)); } - logger.info("Processing {} pages", totalPages); + this.logger.info("Processing {} pages", totalPages); return readDocuments; } diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReader.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReader.java index 9f5d05530..a5943d45d 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReader.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReader.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf; -import java.awt.Rectangle; +import java.awt.*; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import org.apache.pdfbox.pdfparser.PDFParser; import org.apache.pdfbox.pdmodel.PDDocument; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentReader; import org.springframework.ai.reader.pdf.config.ParagraphManager; @@ -48,8 +49,6 @@ import org.springframework.util.StringUtils; */ public class ParagraphPdfDocumentReader implements DocumentReader { - private final Logger logger = LoggerFactory.getLogger(getClass()); - // Constants for metadata keys private static final String METADATA_START_PAGE = "page_number"; @@ -61,14 +60,16 @@ public class ParagraphPdfDocumentReader implements DocumentReader { private static final String METADATA_FILE_NAME = "file_name"; - private final ParagraphManager paragraphTextExtractor; - protected final PDDocument document; - private PdfDocumentReaderConfig config; + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private final ParagraphManager paragraphTextExtractor; protected String resourceFileName; + private PdfDocumentReaderConfig config; + /** * Constructs a ParagraphPdfDocumentReader using a resource URL. * @param resourceUrl The URL of the PDF resource. @@ -132,7 +133,7 @@ public class ParagraphPdfDocumentReader implements DocumentReader { List documents = new ArrayList<>(paragraphs.size()); if (!CollectionUtils.isEmpty(paragraphs)) { - logger.info("Start processing paragraphs from PDF"); + this.logger.info("Start processing paragraphs from PDF"); Iterator itr = paragraphs.iterator(); var current = itr.next(); @@ -151,7 +152,7 @@ public class ParagraphPdfDocumentReader implements DocumentReader { } } } - logger.info("End processing paragraphs from PDF"); + this.logger.info("End processing paragraphs from PDF"); return documents; } diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHints.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHints.java index 0e2c7fbe9..ae5b8588f 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHints.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf.aot; +import java.io.IOException; +import java.util.Set; + import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHintsRegistrar; import org.springframework.core.io.support.PathMatchingResourcePatternResolver; -import java.io.IOException; -import java.util.Set; - /** * The PdfReaderRuntimeHints class is responsible for registering runtime hints for PDFBox * resources. diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/ParagraphManager.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/ParagraphManager.java index 011880743..555b23fa0 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/ParagraphManager.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/ParagraphManager.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf.config; import java.io.IOException; @@ -39,34 +40,6 @@ import org.springframework.util.CollectionUtils; */ public class ParagraphManager { - /** - * Represents a document paragraph metadata and hierarchy. - * - * @param parent Parent paragraph that will contain a children paragraphs. - * @param title Paragraph title as it appears in the PDF document. - * @param level The TOC deepness level for this paragraph. The root is at level 0. - * @param startPageNumber The page number in the PDF where this paragraph begins. - * @param endPageNumber The page number in the PDF where this paragraph ends. - * @param children Sub-paragraphs for this paragraph. - */ - public record Paragraph(Paragraph parent, String title, int level, int startPageNumber, int endPageNumber, - int position, List children) { - - public Paragraph(Paragraph parent, String title, int level, int startPageNumber, int endPageNumber, - int position) { - this(parent, title, level, startPageNumber, endPageNumber, position, new ArrayList<>()); - } - - @Override - public String toString() { - String indent = (level < 0) ? "" : new String(new char[level * 2]).replace('\0', ' '); - - return indent + " " + level + ") " + title + " [" + startPageNumber + "," + endPageNumber + "], children = " - + children.size() + ", pos = " + position; - } - - } - /** * Root of the paragraphs tree. */ @@ -90,7 +63,7 @@ public class ParagraphManager { new Paragraph(null, "root", -1, 1, this.document.getNumberOfPages(), 0), this.document.getDocumentCatalog().getDocumentOutline(), 0); - printParagraph(rootParagraph, System.out); + printParagraph(this.rootParagraph, System.out); } catch (Exception e) { throw new RuntimeException(e); @@ -203,4 +176,32 @@ public class ParagraphManager { return resultList; } + /** + * Represents a document paragraph metadata and hierarchy. + * + * @param parent Parent paragraph that will contain a children paragraphs. + * @param title Paragraph title as it appears in the PDF document. + * @param level The TOC deepness level for this paragraph. The root is at level 0. + * @param startPageNumber The page number in the PDF where this paragraph begins. + * @param endPageNumber The page number in the PDF where this paragraph ends. + * @param children Sub-paragraphs for this paragraph. + */ + public record Paragraph(Paragraph parent, String title, int level, int startPageNumber, int endPageNumber, + int position, List children) { + + public Paragraph(Paragraph parent, String title, int level, int startPageNumber, int endPageNumber, + int position) { + this(parent, title, level, startPageNumber, endPageNumber, position, new ArrayList<>()); + } + + @Override + public String toString() { + String indent = (this.level < 0) ? "" : new String(new char[this.level * 2]).replace('\0', ' '); + + return indent + " " + this.level + ") " + this.title + " [" + this.startPageNumber + "," + + this.endPageNumber + "], children = " + this.children.size() + ", pos = " + this.position; + } + + } + } diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/PdfDocumentReaderConfig.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/PdfDocumentReaderConfig.java index 5a375b3d4..b80ff8e9b 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/PdfDocumentReaderConfig.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/PdfDocumentReaderConfig.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf.config; import org.springframework.ai.reader.ExtractedTextFormatter; @@ -40,6 +41,14 @@ public class PdfDocumentReaderConfig { public final ExtractedTextFormatter pageExtractedTextFormatter; + private PdfDocumentReaderConfig(PdfDocumentReaderConfig.Builder builder) { + this.pagesPerDocument = builder.pagesPerDocument; + this.pageBottomMargin = builder.pageBottomMargin; + this.pageTopMargin = builder.pageTopMargin; + this.pageExtractedTextFormatter = builder.pageExtractedTextFormatter; + this.reversedParagraphPosition = builder.reversedParagraphPosition; + } + /** * Start building a new configuration. * @return The entry point for creating a new configuration. @@ -56,14 +65,6 @@ public class PdfDocumentReaderConfig { return builder().build(); } - private PdfDocumentReaderConfig(PdfDocumentReaderConfig.Builder builder) { - this.pagesPerDocument = builder.pagesPerDocument; - this.pageBottomMargin = builder.pageBottomMargin; - this.pageTopMargin = builder.pageTopMargin; - this.pageExtractedTextFormatter = builder.pageExtractedTextFormatter; - this.reversedParagraphPosition = builder.reversedParagraphPosition; - } - public static class Builder { private int pagesPerDocument = 1; diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/ForkPDFLayoutTextStripper.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/ForkPDFLayoutTextStripper.java index 80e35acb3..ea1980ff6 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/ForkPDFLayoutTextStripper.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/ForkPDFLayoutTextStripper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -180,8 +180,9 @@ public class ForkPDFLayoutTextStripper extends PDFTextStripper { double height = textPosition.getHeight(); int numberOfLines = (int) (Math.floor(textYPosition - previousTextYPosition) / height); numberOfLines = Math.max(1, numberOfLines - 1); // exclude current new line - if (DEBUG) + if (DEBUG) { System.out.println(height + " " + numberOfLines); + } return numberOfLines; } else { @@ -191,7 +192,7 @@ public class ForkPDFLayoutTextStripper extends PDFTextStripper { private TextLine addNewLine() { TextLine textLine = new TextLine(this.getCurrentPageWidth()); - textLineList.add(textLine); + this.textLineList.add(textLine); return textLine; } @@ -248,7 +249,7 @@ class TextLine { } public String getLine() { - return line; + return this.line; } private int computeIndexForCharacter(final Character character) { @@ -313,7 +314,7 @@ class TextLine { private void completeLineWithSpaces() { for (int i = 0; i < this.getLineLength(); ++i) { - line += SPACE_CHARACTER; + this.line += SPACE_CHARACTER; } } @@ -350,8 +351,9 @@ class Character { this.isFirstCharacterOfAWord = isFirstCharacterOfAWord; this.isCharacterAtTheBeginningOfNewLine = isCharacterAtTheBeginningOfNewLine; this.isCharacterCloseToPreviousWord = isCharacterPartOfASentence; - if (ForkPDFLayoutTextStripper.DEBUG) + if (ForkPDFLayoutTextStripper.DEBUG) { System.out.println(this.toString()); + } } public char getCharacterValue() { @@ -384,14 +386,14 @@ class Character { public String toString() { String toString = ""; - toString += index; + toString += this.index; toString += " "; - toString += characterValue; - toString += " isCharacterPartOfPreviousWord=" + isCharacterPartOfPreviousWord; - toString += " isFirstCharacterOfAWord=" + isFirstCharacterOfAWord; - toString += " isCharacterAtTheBeginningOfNewLine=" + isCharacterAtTheBeginningOfNewLine; - toString += " isCharacterPartOfASentence=" + isCharacterCloseToPreviousWord; - toString += " isCharacterCloseToPreviousWord=" + isCharacterCloseToPreviousWord; + toString += this.characterValue; + toString += " isCharacterPartOfPreviousWord=" + this.isCharacterPartOfPreviousWord; + toString += " isFirstCharacterOfAWord=" + this.isFirstCharacterOfAWord; + toString += " isCharacterAtTheBeginningOfNewLine=" + this.isCharacterAtTheBeginningOfNewLine; + toString += " isCharacterPartOfASentence=" + this.isCharacterCloseToPreviousWord; + toString += " isCharacterCloseToPreviousWord=" + this.isCharacterCloseToPreviousWord; return toString; } @@ -424,12 +426,12 @@ class CharacterFactory { this.isCharacterCloseToPreviousWord = this.isCharacterCloseToPreviousWord(textPosition); char character = this.getCharacterFromTextPosition(textPosition); int index = (int) textPosition.getX() / ForkPDFLayoutTextStripper.OUTPUT_SPACE_CHARACTER_WIDTH_IN_PT; - return new Character(character, index, isCharacterPartOfPreviousWord, isFirstCharacterOfAWord, - isCharacterAtTheBeginningOfNewLine, isCharacterCloseToPreviousWord); + return new Character(character, index, this.isCharacterPartOfPreviousWord, this.isFirstCharacterOfAWord, + this.isCharacterAtTheBeginningOfNewLine, this.isCharacterCloseToPreviousWord); } private boolean isCharacterAtTheBeginningOfNewLine(final TextPosition textPosition) { - if (!firstCharacterOfLineFound) { + if (!this.firstCharacterOfLineFound) { return true; } TextPosition previousTextPosition = this.getPreviousTextPosition(); @@ -438,18 +440,18 @@ class CharacterFactory { } private boolean isFirstCharacterOfAWord(final TextPosition textPosition) { - if (!firstCharacterOfLineFound) { + if (!this.firstCharacterOfLineFound) { return true; } - double numberOfSpaces = this.numberOfSpacesBetweenTwoCharacters(previousTextPosition, textPosition); + double numberOfSpaces = this.numberOfSpacesBetweenTwoCharacters(this.previousTextPosition, textPosition); return (numberOfSpaces > 1) || this.isCharacterAtTheBeginningOfNewLine(textPosition); } private boolean isCharacterCloseToPreviousWord(final TextPosition textPosition) { - if (!firstCharacterOfLineFound) { + if (!this.firstCharacterOfLineFound) { return false; } - double numberOfSpaces = this.numberOfSpacesBetweenTwoCharacters(previousTextPosition, textPosition); + double numberOfSpaces = this.numberOfSpacesBetweenTwoCharacters(this.previousTextPosition, textPosition); return (numberOfSpaces > 1 && numberOfSpaces <= ForkPDFLayoutTextStripper.OUTPUT_SPACE_CHARACTER_WIDTH_IN_PT); } @@ -485,4 +487,4 @@ class CharacterFactory { this.previousTextPosition = previousTextPosition; } -} \ No newline at end of file +} diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/PDFLayoutTextStripperByArea.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/PDFLayoutTextStripperByArea.java index 44bcb511a..a5d39db89 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/PDFLayoutTextStripperByArea.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/PDFLayoutTextStripperByArea.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf.layout; import java.awt.geom.Rectangle2D; @@ -70,8 +71,8 @@ public class PDFLayoutTextStripperByArea extends ForkPDFLayoutTextStripper { * java coordinates (y == 0 is top), not PDF coordinates (y == 0 is bottom). */ public void addRegion(String regionName, Rectangle2D rect) { - regions.add(regionName); - regionArea.put(regionName, rect); + this.regions.add(regionName); + this.regionArea.put(regionName, rect); } /** @@ -80,8 +81,8 @@ public class PDFLayoutTextStripperByArea extends ForkPDFLayoutTextStripper { * @param regionName The name of the region to delete. */ public void removeRegion(String regionName) { - regions.remove(regionName); - regionArea.remove(regionName); + this.regions.remove(regionName); + this.regionArea.remove(regionName); } /** @@ -89,7 +90,7 @@ public class PDFLayoutTextStripperByArea extends ForkPDFLayoutTextStripper { * @return A list of java.lang.String objects to identify the region names. */ public List getRegions() { - return regions; + return this.regions; } /** @@ -98,7 +99,7 @@ public class PDFLayoutTextStripperByArea extends ForkPDFLayoutTextStripper { * @return The text that was identified in that region. */ public String getTextForRegion(String regionName) { - StringWriter text = regionText.get(regionName); + StringWriter text = this.regionText.get(regionName); return text.toString(); } @@ -108,14 +109,14 @@ public class PDFLayoutTextStripperByArea extends ForkPDFLayoutTextStripper { * @throws IOException If there is an error while extracting text. */ public void extractRegions(PDPage page) throws IOException { - for (String regionName : regions) { + for (String regionName : this.regions) { setStartPage(getCurrentPageNo()); setEndPage(getCurrentPageNo()); // reset the stored text for the region so this class can be reused. ArrayList> regionCharactersByArticle = new ArrayList>(); regionCharactersByArticle.add(new ArrayList()); - regionCharacterList.put(regionName, regionCharactersByArticle); - regionText.put(regionName, new StringWriter()); + this.regionCharacterList.put(regionName, regionCharactersByArticle); + this.regionText.put(regionName, new StringWriter()); } if (page.hasContents()) { @@ -128,10 +129,10 @@ public class PDFLayoutTextStripperByArea extends ForkPDFLayoutTextStripper { */ @Override protected void processTextPosition(TextPosition text) { - for (Map.Entry regionAreaEntry : regionArea.entrySet()) { + for (Map.Entry regionAreaEntry : this.regionArea.entrySet()) { Rectangle2D rect = regionAreaEntry.getValue(); if (rect.contains(text.getX(), text.getY())) { - charactersByArticle = regionCharacterList.get(regionAreaEntry.getKey()); + this.charactersByArticle = this.regionCharacterList.get(regionAreaEntry.getKey()); super.processTextPosition(text); } } @@ -143,9 +144,9 @@ public class PDFLayoutTextStripperByArea extends ForkPDFLayoutTextStripper { */ @Override protected void writePage() throws IOException { - for (String region : regionArea.keySet()) { - charactersByArticle = regionCharacterList.get(region); - output = regionText.get(region); + for (String region : this.regionArea.keySet()) { + this.charactersByArticle = this.regionCharacterList.get(region); + this.output = this.regionText.get(region); super.writePage(); } } diff --git a/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/PagePdfDocumentReaderTests.java b/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/PagePdfDocumentReaderTests.java index f42d7ef3d..71c230faf 100644 --- a/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/PagePdfDocumentReaderTests.java +++ b/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/PagePdfDocumentReaderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf; import java.util.List; diff --git a/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReaderTests.java b/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReaderTests.java index eec22054d..5b45f14de 100644 --- a/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReaderTests.java +++ b/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReaderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf; import org.junit.jupiter.api.Test; diff --git a/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHintsTests.java b/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHintsTests.java index b7e0cd12e..c409abaa2 100644 --- a/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHintsTests.java +++ b/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.pdf.aot; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; + import org.springframework.aot.hint.RuntimeHints; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.resource; diff --git a/document-readers/tika-reader/pom.xml b/document-readers/tika-reader/pom.xml index 35abb98c6..59297edd7 100644 --- a/document-readers/tika-reader/pom.xml +++ b/document-readers/tika-reader/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/document-readers/tika-reader/src/main/java/org/springframework/ai/reader/tika/TikaDocumentReader.java b/document-readers/tika-reader/src/main/java/org/springframework/ai/reader/tika/TikaDocumentReader.java index f004cd197..1619e2bc9 100644 --- a/document-readers/tika-reader/src/main/java/org/springframework/ai/reader/tika/TikaDocumentReader.java +++ b/document-readers/tika-reader/src/main/java/org/springframework/ai/reader/tika/TikaDocumentReader.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.tika; import java.io.IOException; diff --git a/document-readers/tika-reader/src/test/java/org/springframework/ai/reader/tika/TikaDocumentReaderTests.java b/document-readers/tika-reader/src/test/java/org/springframework/ai/reader/tika/TikaDocumentReaderTests.java index 84a167ef1..5ae1e7a7d 100644 --- a/document-readers/tika-reader/src/test/java/org/springframework/ai/reader/tika/TikaDocumentReaderTests.java +++ b/document-readers/tika-reader/src/test/java/org/springframework/ai/reader/tika/TikaDocumentReaderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader.tika; import org.junit.jupiter.params.ParameterizedTest; diff --git a/models/spring-ai-anthropic/pom.xml b/models/spring-ai-anthropic/pom.xml index beacc8533..b24615394 100644 --- a/models/spring-ai-anthropic/pom.xml +++ b/models/spring-ai-anthropic/pom.xml @@ -1,4 +1,20 @@ + + diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 05ece850a..d67431a98 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic; import java.util.ArrayList; @@ -28,6 +29,9 @@ import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.ai.anthropic.api.AnthropicApi.AnthropicMessage; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest; @@ -42,7 +46,11 @@ import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.model.*; +import org.springframework.ai.chat.model.AbstractToolCallSupport; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; @@ -61,9 +69,6 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - /** * The {@link ChatModel} implementation for the Anthropic service. * @@ -76,16 +81,21 @@ import reactor.core.publisher.Mono; */ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatModel { - private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModel.class); - - private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); - public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getValue(); public static final Integer DEFAULT_MAX_TOKENS = 500; public static final Double DEFAULT_TEMPERATURE = 0.8; + private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModel.class); + + private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + + /** + * The retry template used to retry the OpenAI API calls. + */ + public final RetryTemplate retryTemplate; + /** * The lower-level API for the Anthropic service. */ @@ -96,11 +106,6 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM */ private final AnthropicChatOptions defaultOptions; - /** - * The retry template used to retry the OpenAI API calls. - */ - public final RetryTemplate retryTemplate; - /** * Observation registry used for instrumentation. */ diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index c40539f8c..08ea7de8c 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic; import java.util.ArrayList; @@ -91,91 +92,24 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions return new Builder(); } - public static class Builder { - - private final AnthropicChatOptions options = new AnthropicChatOptions(); - - public Builder withModel(String model) { - this.options.model = model; - return this; - } - - public Builder withModel(AnthropicApi.ChatModel model) { - this.options.model = model.getValue(); - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.maxTokens = maxTokens; - return this; - } - - public Builder withMetadata(ChatCompletionRequest.Metadata metadata) { - this.options.metadata = metadata; - return this; - } - - public Builder withStopSequences(List stopSequences) { - this.options.stopSequences = stopSequences; - return this; - } - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public Builder withTopK(Integer topK) { - this.options.topK = topK; - return this; - } - - public Builder withFunctionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; - return this; - } - - public Builder withFunctions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; - return this; - } - - public Builder withFunction(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); - return this; - } - - public Builder withProxyToolCalls(Boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; - return this; - } - - public Builder withToolContext(Map toolContext) { - if (this.options.toolContext == null) { - this.options.toolContext = toolContext; - } - else { - this.options.toolContext.putAll(toolContext); - } - return this; - } - - public AnthropicChatOptions build() { - return this.options; - } - + public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) { + return builder().withModel(fromOptions.getModel()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withMetadata(fromOptions.getMetadata()) + .withStopSequences(fromOptions.getStopSequences()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTopK(fromOptions.getTopK()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withToolContext(fromOptions.getToolContext()) + .build(); } @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -293,19 +227,86 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions return fromOptions(this); } - public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) { - return builder().withModel(fromOptions.getModel()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withMetadata(fromOptions.getMetadata()) - .withStopSequences(fromOptions.getStopSequences()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withTopK(fromOptions.getTopK()) - .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) - .withFunctions(fromOptions.getFunctions()) - .withProxyToolCalls(fromOptions.getProxyToolCalls()) - .withToolContext(fromOptions.getToolContext()) - .build(); + public static class Builder { + + private final AnthropicChatOptions options = new AnthropicChatOptions(); + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withModel(AnthropicApi.ChatModel model) { + this.options.model = model.getValue(); + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withMetadata(ChatCompletionRequest.Metadata metadata) { + this.options.metadata = metadata; + return this; + } + + public Builder withStopSequences(List stopSequences) { + this.options.stopSequences = stopSequences; + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder withTopK(Integer topK) { + this.options.topK = topK; + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public AnthropicChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java index bf56e842b..71a47d1e0 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.aot; import org.springframework.ai.anthropic.api.AnthropicApi; diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index 5d893ba6e..35fa4faf6 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.api; import java.util.ArrayList; @@ -23,6 +24,14 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Predicate; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; @@ -38,15 +47,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonSubTypes; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - /** * @author Christian Tzolov * @author Mariusz Bernacki @@ -57,12 +57,6 @@ public class AnthropicApi { public static final String PROVIDER_NAME = AiProvider.ANTHROPIC.value(); - private static final String HEADER_X_API_KEY = "x-api-key"; - - private static final String HEADER_ANTHROPIC_VERSION = "anthropic-version"; - - private static final String HEADER_ANTHROPIC_BETA = "anthropic-beta"; - public static final String DEFAULT_BASE_URL = "https://api.anthropic.com"; public static final String DEFAULT_ANTHROPIC_VERSION = "2023-06-01"; @@ -71,10 +65,18 @@ public class AnthropicApi { public static final String BETA_MAX_TOKENS = "max-tokens-3-5-sonnet-2024-07-15"; + private static final String HEADER_X_API_KEY = "x-api-key"; + + private static final String HEADER_ANTHROPIC_VERSION = "anthropic-version"; + + private static final String HEADER_ANTHROPIC_BETA = "anthropic-beta"; + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; private final RestClient restClient; + private final StreamHelper streamHelper = new StreamHelper(); + private WebClient webClient; /** @@ -141,6 +143,74 @@ public class AnthropicApi { .build(); } + /** + * Creates a model response for the given chat conversation. + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletionResponse} as a body and HTTP + * status code and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post() + .uri("/v1/messages") + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletionResponse.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + AtomicBoolean isInsideTool = new AtomicBoolean(false); + + AtomicReference chatCompletionReference = new AtomicReference<>(); + + return this.webClient.post() + .uri("/v1/messages") + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, StreamEvent.class)) + .filter(event -> event.type() != EventType.PING) + // Detect if the chunk is part of a streaming function call. + .map(event -> { + if (this.streamHelper.isToolUseStart(event)) { + isInsideTool.set(true); + } + return event; + }) + // Group all chunks belonging to the same function call. + .windowUntil(event -> { + if (isInsideTool.get() && this.streamHelper.isToolUseFinish(event)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + // Merging the window chunks into a single chunk. + .concatMapIterable(window -> { + Mono monoChunk = window.reduce(new ToolUseAggregationEvent(), + this.streamHelper::mergeToolUseEvents); + return List.of(monoChunk); + }) + .flatMap(mono -> mono) + .map(event -> this.streamHelper.eventToChatCompletionResponse(event, chatCompletionReference)) + .filter(chatCompletionResponse -> chatCompletionResponse.type() != null); + } + /** * Check the Models * overview and models for @@ -257,6 +407,14 @@ public class AnthropicApi { this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null); } + public static ChatCompletionRequestBuilder builder() { + return new ChatCompletionRequestBuilder(); + } + + public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { + return new ChatCompletionRequestBuilder(request); + } + /** * @param userId An external identifier for the user who is associated with the * request. This should be a uuid, hash value, or other opaque identifier. @@ -265,15 +423,9 @@ public class AnthropicApi { */ @JsonInclude(Include.NON_NULL) public record Metadata(@JsonProperty("user_id") String userId) { + } - public static ChatCompletionRequestBuilder builder() { - return new ChatCompletionRequestBuilder(); - } - - public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { - return new ChatCompletionRequestBuilder(request); - } } public static class ChatCompletionRequestBuilder { @@ -378,12 +530,16 @@ public class AnthropicApi { } public ChatCompletionRequest build() { - return new ChatCompletionRequest(model, messages, system, maxTokens, metadata, stopSequences, stream, - temperature, topP, topK, tools); + return new ChatCompletionRequest(this.model, this.messages, this.system, this.maxTokens, this.metadata, + this.stopSequences, this.stream, this.temperature, this.topP, this.topK, this.tools); } } + /////////////////////////////////////// + /// ERROR EVENT + /////////////////////////////////////// + /** * Input messages. * @@ -535,9 +691,15 @@ public class AnthropicApi { public Source(String mediaType, String data) { this("base64", mediaType, data); } + } + } + /////////////////////////////////////// + /// CONTENT_BLOCK EVENTS + /////////////////////////////////////// + @JsonInclude(Include.NON_NULL) public record Tool(// @formatter:off @JsonProperty("name") String name, @@ -546,6 +708,8 @@ public class AnthropicApi { // @formatter:on } + // CB START EVENT + /** * @param id Unique object identifier. The format and length of IDs may change over * time. @@ -572,6 +736,8 @@ public class AnthropicApi { // @formatter:on } + // CB DELTA EVENT + /** * Usage statistics. * @@ -585,94 +751,7 @@ public class AnthropicApi { // @formatter:off } - - /////////////////////////////////////// - /// ERROR EVENT - /////////////////////////////////////// - - /** - * The evnt type of the streamed chunk. - */ - public enum EventType { - - /** - * Message start event. Contains a Message object with empty content. - */ - @JsonProperty("message_start") - MESSAGE_START, - - /** - * Message delta event, indicating top-level changes to the final Message object. - */ - @JsonProperty("message_delta") - MESSAGE_DELTA, - - /** - * A final message stop event. - */ - @JsonProperty("message_stop") - MESSAGE_STOP, - - /** - * - */ - @JsonProperty("content_block_start") - CONTENT_BLOCK_START, - - /** - * - */ - @JsonProperty("content_block_delta") - CONTENT_BLOCK_DELTA, - - /** - * - */ - @JsonProperty("content_block_stop") - CONTENT_BLOCK_STOP, - - /** - * - */ - @JsonProperty("error") - ERROR, - - /** - * - */ - @JsonProperty("ping") - PING, - - /** - * Artifically created event to aggregate tool use events. - */ - TOOL_USE_AGGREATE; - - } - - @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type", - visible = true) - @JsonSubTypes({ @JsonSubTypes.Type(value = ContentBlockStartEvent.class, name = "content_block_start"), - @JsonSubTypes.Type(value = ContentBlockDeltaEvent.class, name = "content_block_delta"), - @JsonSubTypes.Type(value = ContentBlockStopEvent.class, name = "content_block_stop"), - - @JsonSubTypes.Type(value = PingEvent.class, name = "ping"), - - @JsonSubTypes.Type(value = ErrorEvent.class, name = "error"), - - @JsonSubTypes.Type(value = MessageStartEvent.class, name = "message_start"), - @JsonSubTypes.Type(value = MessageDeltaEvent.class, name = "message_delta"), - @JsonSubTypes.Type(value = MessageStopEvent.class, name = "message_stop") }) - public interface StreamEvent { - - @JsonProperty("type") - EventType type(); - - } - - /////////////////////////////////////// - /// CONTENT_BLOCK EVENTS - /////////////////////////////////////// + /// ECB STOP /** * Special event used to aggregate multiple tool use events into a single event with @@ -736,13 +815,17 @@ public class AnthropicApi { @Override public String toString() { - return "EventToolUseBuilder [index=" + index + ", id=" + id + ", name=" + name + ", partialJson=" - + partialJson + ", toolUseMap=" + toolContentBlocks + "]"; + return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name + ", partialJson=" + + this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]"; } } - // CB START EVENT + /////////////////////////////////////// + /// MESSAGE EVENTS + /////////////////////////////////////// + + // MESSAGE START EVENT @JsonInclude(Include.NON_NULL) public record ContentBlockStartEvent(// @formatter:off @@ -773,7 +856,7 @@ public class AnthropicApi { } }// @formatter:on - // CB DELTA EVENT + // MESSAGE DELTA EVENT @JsonInclude(Include.NON_NULL) public record ContentBlockDeltaEvent(// @formatter:off @@ -803,7 +886,7 @@ public class AnthropicApi { } }// @formatter:on - /// ECB STOP + // MESSAGE STOP EVENT @JsonInclude(Include.NON_NULL) public record ContentBlockStopEvent(// @formatter:off @@ -811,20 +894,12 @@ public class AnthropicApi { @JsonProperty("index") Integer index) implements StreamEvent { }// @formatter:on - /////////////////////////////////////// - /// MESSAGE EVENTS - /////////////////////////////////////// - - // MESSAGE START EVENT - @JsonInclude(Include.NON_NULL) public record MessageStartEvent(// @formatter:off @JsonProperty("type") EventType type, @JsonProperty("message") ChatCompletionResponse message) implements StreamEvent { }// @formatter:on - // MESSAGE DELTA EVENT - @JsonInclude(Include.NON_NULL) public record MessageDeltaEvent(// @formatter:off @JsonProperty("type") EventType type, @@ -843,8 +918,6 @@ public class AnthropicApi { } }// @formatter:on - // MESSAGE STOP EVENT - @JsonInclude(Include.NON_NULL) public record MessageStopEvent(// @formatter:off @JsonProperty("type") EventType type) implements StreamEvent { @@ -873,74 +946,4 @@ public class AnthropicApi { @JsonProperty("type") EventType type) implements StreamEvent { }// @formatter:on - /** - * Creates a model response for the given chat conversation. - * @param chatRequest The chat completion request. - * @return Entity response with {@link ChatCompletionResponse} as a body and HTTP - * status code and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); - - return this.restClient.post() - .uri("/v1/messages") - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletionResponse.class); - } - - private final StreamHelper streamHelper = new StreamHelper(); - - /** - * Creates a streaming chat response for the given chat conversation. - * @param chatRequest The chat completion request. Must have the stream property set - * to true. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); - - AtomicBoolean isInsideTool = new AtomicBoolean(false); - - AtomicReference chatCompletionReference = new AtomicReference<>(); - - return this.webClient.post() - .uri("/v1/messages") - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - .takeUntil(SSE_DONE_PREDICATE) - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, StreamEvent.class)) - .filter(event -> event.type() != EventType.PING) - // Detect if the chunk is part of a streaming function call. - .map(event -> { - if (this.streamHelper.isToolUseStart(event)) { - isInsideTool.set(true); - } - return event; - }) - // Group all chunks belonging to the same function call. - .windowUntil(event -> { - if (isInsideTool.get() && this.streamHelper.isToolUseFinish(event)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - // Merging the window chunks into a single chunk. - .concatMapIterable(window -> { - Mono monoChunk = window.reduce(new ToolUseAggregationEvent(), - this.streamHelper::mergeToolUseEvents); - return List.of(monoChunk); - }) - .flatMap(mono -> mono) - .map(event -> streamHelper.eventToChatCompletionResponse(event, chatCompletionReference)) - .filter(chatCompletionResponse -> chatCompletionResponse.type() != null); - } - } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java index 054bf023b..677bdb2e4 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.api; import java.util.ArrayList; @@ -22,22 +23,22 @@ import java.util.concurrent.atomic.AtomicReference; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type; -import org.springframework.ai.anthropic.api.AnthropicApi.Role; -import org.springframework.ai.anthropic.api.AnthropicApi.Usage; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent; -import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent; -import org.springframework.ai.anthropic.api.AnthropicApi.ToolUseAggregationEvent; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; -import org.springframework.ai.anthropic.api.AnthropicApi.MessageDeltaEvent; -import org.springframework.ai.anthropic.api.AnthropicApi.MessageStartEvent; -import org.springframework.ai.anthropic.api.AnthropicApi.StreamEvent; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent.ContentBlockDeltaJson; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent.ContentBlockDeltaText; +import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent.ContentBlockText; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent.ContentBlockToolUse; import org.springframework.ai.anthropic.api.AnthropicApi.EventType; +import org.springframework.ai.anthropic.api.AnthropicApi.MessageDeltaEvent; +import org.springframework.ai.anthropic.api.AnthropicApi.MessageStartEvent; +import org.springframework.ai.anthropic.api.AnthropicApi.Role; +import org.springframework.ai.anthropic.api.AnthropicApi.StreamEvent; +import org.springframework.ai.anthropic.api.AnthropicApi.ToolUseAggregationEvent; +import org.springframework.ai.anthropic.api.AnthropicApi.Usage; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; /** * Helper class to support streaming function calling. diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicRateLimit.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicRateLimit.java index 0ed5cdde1..83edc7e15 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicRateLimit.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicRateLimit.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.metadata; import java.time.Duration; diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java index 1de5edc6a..fbafc2297 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.metadata; import org.springframework.ai.anthropic.api.AnthropicApi; @@ -27,10 +28,6 @@ import org.springframework.util.Assert; */ public class AnthropicUsage implements Usage { - public static AnthropicUsage from(AnthropicApi.Usage usage) { - return new AnthropicUsage(usage); - } - private final AnthropicApi.Usage usage; protected AnthropicUsage(AnthropicApi.Usage usage) { @@ -38,6 +35,10 @@ public class AnthropicUsage implements Usage { this.usage = usage; } + public static AnthropicUsage from(AnthropicApi.Usage usage) { + return new AnthropicUsage(usage); + } + protected AnthropicApi.Usage getUsage() { return this.usage; } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java index 1a816fc18..144a94537 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.anthropic; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.anthropic; import java.io.IOException; import java.util.ArrayList; @@ -30,11 +29,12 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.ai.anthropic.api.tool.MockWeatherService; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.model.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; @@ -47,6 +47,7 @@ import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; @@ -59,7 +60,7 @@ import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = AnthropicChatModelIT.Config.class, properties = "spring.ai.retry.on-http-codes=429") @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".+") @@ -76,17 +77,25 @@ class AnthropicChatModelIT { @Value("classpath:/prompts/system-message.st") private Resource systemResource; + private static void validateChatResponseMetadata(ChatResponse response, String model) { + assertThat(response.getMetadata().getId()).isNotEmpty(); + assertThat(response.getMetadata().getModel()).containsIgnoringCase(model); + assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); + } + @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20241022" }) void roleTest(String modelName) { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage), AnthropicChatOptions.builder().withModel(modelName).build()); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0); assertThat(response.getMetadata().getUsage().getPromptTokens()).isGreaterThan(0); @@ -103,17 +112,17 @@ class AnthropicChatModelIT { void testMessageHistory() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage), AnthropicChatOptions.builder().withModel("claude-3-sonnet-20240229").build()); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); var promptWithMessageHistory = new Prompt(List.of(new UserMessage("Dummy"), response.getResult().getOutput(), new UserMessage("Repeat the last assistant message."))); - response = chatModel.call(promptWithMessageHistory); + response = this.chatModel.call(promptWithMessageHistory); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); } @@ -167,16 +176,13 @@ class AnthropicChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = mapOutputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -189,7 +195,7 @@ class AnthropicChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = beanOutputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -210,7 +216,7 @@ class AnthropicChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = streamingChatModel.stream(prompt) + String generationTextFromStream = this.streamingChatModel.stream(prompt) .collectList() .block() .stream() @@ -234,7 +240,7 @@ class AnthropicChatModelIT { var userMessage = new UserMessage("Explain what do you see on this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel.call(new Prompt(List.of(userMessage))); + var response = this.chatModel.call(new Prompt(List.of(userMessage))); logger.info(response.getResult().getOutput().getContent()); assertThat(response.getResult().getOutput().getContent()).contains("banan", "apple", "basket"); @@ -257,7 +263,7 @@ class AnthropicChatModelIT { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -284,7 +290,7 @@ class AnthropicChatModelIT { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -301,7 +307,7 @@ class AnthropicChatModelIT { void validateCallResponseMetadata() { String model = AnthropicApi.ChatModel.CLAUDE_2_1.getName(); // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(AnthropicChatOptions.builder().withModel(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() @@ -316,7 +322,7 @@ class AnthropicChatModelIT { void validateStreamCallResponseMetadata() { String model = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName(); // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(AnthropicChatOptions.builder().withModel(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .stream() @@ -328,12 +334,8 @@ class AnthropicChatModelIT { validateChatResponseMetadata(response, model); } - private static void validateChatResponseMetadata(ChatResponse response, String model) { - assertThat(response.getMetadata().getId()).isNotEmpty(); - assertThat(response.getMetadata().getModel()).containsIgnoringCase(model); - assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive(); - assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive(); - assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); + record ActorsFilmsRecord(String actor, List movies) { + } @SpringBootConfiguration @@ -360,4 +362,4 @@ class AnthropicChatModelIT { } -} \ No newline at end of file +} diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java index 17ef19c90..968e0c6a1 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.anthropic; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.anthropic; import java.util.List; import java.util.stream.Collectors; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; @@ -39,9 +42,7 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; /** * Integration tests for observation instrumentation in {@link AnthropicChatModel}. @@ -61,7 +62,7 @@ public class AnthropicChatModelObservationIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -77,7 +78,7 @@ public class AnthropicChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -99,7 +100,7 @@ public class AnthropicChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -121,7 +122,7 @@ public class AnthropicChatModelObservationIT { } private void validate(ChatResponseMetadata responseMetadata, String finishReasons) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicTestConfiguration.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicTestConfiguration.java index e92f4d670..e90a94f87 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicTestConfiguration.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic; import org.springframework.ai.anthropic.api.AnthropicApi; diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java index be251b00d..3dec1f7bc 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic; import org.junit.jupiter.api.Test; diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/EventParsingTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/EventParsingTests.java index d57bf765b..9cd11068b 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/EventParsingTests.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/EventParsingTests.java @@ -1,34 +1,35 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.anthropic; + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.anthropic; import java.io.IOException; import java.nio.charset.Charset; import java.util.List; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.anthropic.api.AnthropicApi.StreamEvent; import org.springframework.core.io.DefaultResourceLoader; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -44,6 +45,7 @@ public class EventParsingTests { .getContentAsString(Charset.defaultCharset()); List events = new ObjectMapper().readerFor(new TypeReference>() { + }).readValue(json); logger.info(events.toString()); diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHintsTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHintsTests.java index 11683d844..f38a6c8e6 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHintsTests.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.aot; +import java.util.Set; + import org.junit.jupiter.api.Test; import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; -import java.util.Set; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java index ceaebdfe6..d83098077 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.anthropic.api; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.anthropic.api; import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.anthropic.api.AnthropicApi.AnthropicMessage; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; @@ -28,7 +29,7 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock; import org.springframework.ai.anthropic.api.AnthropicApi.Role; import org.springframework.http.ResponseEntity; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -43,7 +44,7 @@ public class AnthropicApiIT { AnthropicMessage chatCompletionMessage = new AnthropicMessage(List.of(new ContentBlock("Tell me a Joke?")), Role.USER); - ResponseEntity response = anthropicApi + ResponseEntity response = this.anthropicApi .chatCompletionEntity(new ChatCompletionRequest(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(), List.of(chatCompletionMessage), null, 100, 0.8, false)); @@ -58,7 +59,7 @@ public class AnthropicApiIT { AnthropicMessage chatCompletionMessage = new AnthropicMessage(List.of(new ContentBlock("Tell me a Joke?")), Role.USER); - Flux response = anthropicApi.chatCompletionStream(new ChatCompletionRequest( + Flux response = this.anthropicApi.chatCompletionStream(new ChatCompletionRequest( AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(), List.of(chatCompletionMessage), null, 100, 0.8, true)); assertThat(response).isNotNull(); diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiLegacyToolIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiLegacyToolIT.java index 6e9440e0a..0be31a138 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiLegacyToolIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiLegacyToolIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.api.tool; import java.util.List; @@ -25,10 +26,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.anthropic.api.AnthropicApi; -import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; -import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest; -import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock; import org.springframework.ai.anthropic.api.AnthropicApi.AnthropicMessage; +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest; +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; +import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock; import org.springframework.ai.anthropic.api.AnthropicApi.Role; import org.springframework.ai.anthropic.api.tool.XmlHelper.FunctionCalls; import org.springframework.ai.anthropic.api.tool.XmlHelper.Tools; @@ -60,10 +61,6 @@ import static org.assertj.core.api.Assertions.assertThat; @SuppressWarnings("null") public class AnthropicApiLegacyToolIT { - private static final Logger logger = LoggerFactory.getLogger(AnthropicApiLegacyToolIT.class); - - AnthropicApi anthropicApi = new AnthropicApi(System.getenv("ANTHROPIC_API_KEY")); - public static final String TOO_SYSTEM_PROMPT_TEMPLATE = """ In this environment you have access to a set of tools you can use to answer the user's question. @@ -84,9 +81,9 @@ public class AnthropicApiLegacyToolIT { public static final ConcurrentHashMap FUNCTIONS = new ConcurrentHashMap<>(); - static { - FUNCTIONS.put("getCurrentWeather", new MockWeatherService()); - } + private static final Logger logger = LoggerFactory.getLogger(AnthropicApiLegacyToolIT.class); + + AnthropicApi anthropicApi = new AnthropicApi(System.getenv("ANTHROPIC_API_KEY")); @Test void toolCalls() { @@ -120,7 +117,7 @@ public class AnthropicApiLegacyToolIT { private ResponseEntity doCall(ChatCompletionRequest chatCompletionRequest) { - ResponseEntity response = anthropicApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity response = this.anthropicApi.chatCompletionEntity(chatCompletionRequest); FunctionCalls functionCalls = XmlHelper.extractFunctionCalls(response.getBody().content().get(0).text()); @@ -150,4 +147,8 @@ public class AnthropicApiLegacyToolIT { List.of(chatCompletionMessage2), null, 500, 0.8, false)); } + static { + FUNCTIONS.put("getCurrentWeather", new MockWeatherService()); + } + } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java index c56429762..767d73d0d 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.api.tool; import java.util.ArrayList; @@ -26,11 +27,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.anthropic.api.AnthropicApi; -import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; +import org.springframework.ai.anthropic.api.AnthropicApi.AnthropicMessage; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest; +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type; -import org.springframework.ai.anthropic.api.AnthropicApi.AnthropicMessage; import org.springframework.ai.anthropic.api.AnthropicApi.Role; import org.springframework.ai.anthropic.api.AnthropicApi.Tool; import org.springframework.ai.model.ModelOptionsUtils; @@ -53,16 +54,12 @@ import static org.assertj.core.api.Assertions.assertThat; @SuppressWarnings("null") public class AnthropicApiToolIT { + public static final ConcurrentHashMap FUNCTIONS = new ConcurrentHashMap<>(); + private static final Logger logger = LoggerFactory.getLogger(AnthropicApiLegacyToolIT.class); AnthropicApi anthropicApi = new AnthropicApi(System.getenv("ANTHROPIC_API_KEY")); - public static final ConcurrentHashMap FUNCTIONS = new ConcurrentHashMap<>(); - - static { - FUNCTIONS.put("getCurrentWeather", new MockWeatherService()); - } - List tools = List.of(new Tool("getCurrentWeather", "Get the weather in location. Return temperature in 30°F or 30°C format.", ModelOptionsUtils.jsonToMap(""" { @@ -109,10 +106,10 @@ public class AnthropicApiToolIT { .withMessages(messageConversation) .withMaxTokens(1500) .withTemperature(0.8) - .withTools(tools) + .withTools(this.tools) .build(); - ResponseEntity response = anthropicApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity response = this.anthropicApi.chatCompletionEntity(chatCompletionRequest); List toolToUseList = response.getBody() .content() @@ -155,4 +152,8 @@ public class AnthropicApiToolIT { return doCall(messageConversation); } + static { + FUNCTIONS.put("getCurrentWeather", new MockWeatherService()); + } + } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/MockWeatherService.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/MockWeatherService.java index 762f60fd2..8af458298 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/MockWeatherService.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.anthropic.api.tool; import java.util.function.Function; @@ -28,14 +29,21 @@ import com.fasterxml.jackson.annotation.JsonPropertyDescription; */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, Unit.C); } /** @@ -63,27 +71,22 @@ public class MockWeatherService implements Function toolDescriptions) { - - public record ToolDescription( - @JsonProperty("tool_name") String toolName, - @JsonProperty("description") String description, - @JacksonXmlElementWrapper(localName = "parameters") @JsonProperty("parameter") List parameters) { - - @JacksonXmlRootElement(localName = "parameter") - public record Parameter( - @JsonProperty("name") String name, - @JsonProperty("type") String type, - @JsonProperty("description") String description) { - } - } - } // @formatter:on - - @JsonInclude(Include.NON_NULL) // @formatter:off - @JacksonXmlRootElement(localName = "function_calls") - public record FunctionCalls(@JsonProperty("invoke") Invoke invoke) { - public record Invoke( - @JsonProperty("tool_name") String toolName, - @JsonProperty("parameters") Map parameters) { - } - } // @formatter:on - - @JsonInclude(Include.NON_NULL) // @formatter:off - @JacksonXmlRootElement(localName = "function_results") - public record FunctionResults( - @JacksonXmlElementWrapper(useWrapping = false) @JsonProperty("result") List result) { - - public record Result( - @JsonProperty("tool_name") String toolName, - @JsonProperty("stdout") Object stdout) { - } - } // @formatter:on - public static String extractFunctionCallsXmlBlock(String text) { if (!StringUtils.hasText(text)) { return ""; @@ -149,4 +111,43 @@ public class XmlHelper { } + @JsonInclude(Include.NON_NULL) // @formatter:off + @JacksonXmlRootElement(localName = "tools") + public record Tools( + @JacksonXmlElementWrapper(useWrapping = false) @JsonProperty("tool_description") List toolDescriptions) { + + public record ToolDescription( + @JsonProperty("tool_name") String toolName, + @JsonProperty("description") String description, + @JacksonXmlElementWrapper(localName = "parameters") @JsonProperty("parameter") List parameters) { + + @JacksonXmlRootElement(localName = "parameter") + public record Parameter( + @JsonProperty("name") String name, + @JsonProperty("type") String type, + @JsonProperty("description") String description) { + } + } + } // @formatter:on + + @JsonInclude(Include.NON_NULL) // @formatter:off + @JacksonXmlRootElement(localName = "function_calls") + public record FunctionCalls(@JsonProperty("invoke") Invoke invoke) { + public record Invoke( + @JsonProperty("tool_name") String toolName, + @JsonProperty("parameters") Map parameters) { + } + } // @formatter:on + + @JsonInclude(Include.NON_NULL) // @formatter:off + @JacksonXmlRootElement(localName = "function_results") + public record FunctionResults( + @JacksonXmlElementWrapper(useWrapping = false) @JsonProperty("result") List result) { + + public record Result( + @JsonProperty("tool_name") String toolName, + @JsonProperty("stdout") Object stdout) { + } + } // @formatter:on + } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java index 7fd9a6e6c..d93a84e00 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.anthropic.client; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.anthropic.client; import java.io.IOException; import java.net.URL; @@ -31,6 +30,8 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.anthropic.AnthropicChatOptions; import org.springframework.ai.anthropic.AnthropicTestConfiguration; import org.springframework.ai.anthropic.api.AnthropicApi; @@ -51,7 +52,7 @@ import org.springframework.core.io.Resource; import org.springframework.test.context.ActiveProfiles; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = AnthropicTestConfiguration.class, properties = "spring.ai.retry.on-http-codes=429") @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".+") @@ -66,16 +67,13 @@ class AnthropicChatClientIT { @Value("classpath:/prompts/system-message.st") private Resource systemTextResource; - record ActorsFilms(String actor, List movies) { - } - @Test void call() { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .advisors(new SimpleLoggerAdvisor()) - .system(s -> s.text(systemTextResource) + .system(s -> s.text(this.systemTextResource) .param("name", "Bob") .param("voice", "pirate")) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") @@ -91,7 +89,7 @@ class AnthropicChatClientIT { @Test void listOutputConverterString() { // @formatter:off - List collection = ChatClient.create(chatModel).prompt() + List collection = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("List five {subject}") .param("subject", "ice cream flavors")) .call() @@ -106,7 +104,7 @@ class AnthropicChatClientIT { void listOutputConverterBean() { // @formatter:off - List actorsFilms = ChatClient.create(chatModel).prompt() + List actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography of 5 movies for Tom Hanks and Bill Murray.") .call() .entity(new ParameterizedTypeReference>() { @@ -123,7 +121,7 @@ class AnthropicChatClientIT { var toStringListConverter = new ListOutputConverter(new DefaultConversionService()); // @formatter:off - List flavors = ChatClient.create(chatModel).prompt() + List flavors = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("List five {subject}") .param("subject", "ice cream flavors")) .call() @@ -138,7 +136,7 @@ class AnthropicChatClientIT { @Test void mapOutputConverter() { // @formatter:off - Map result = ChatClient.create(chatModel).prompt() + Map result = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("Provide me a List of {subject}") .param("subject", "an array of numbers from 1 to 9 under they key name 'numbers'")) .call() @@ -153,7 +151,7 @@ class AnthropicChatClientIT { void beanOutputConverter() { // @formatter:off - ActorsFilms actorsFilms = ChatClient.create(chatModel).prompt() + ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography for a random actor.") .call() .entity(ActorsFilms.class); @@ -167,7 +165,7 @@ class AnthropicChatClientIT { void beanOutputConverterRecords() { // @formatter:off - ActorsFilms actorsFilms = ChatClient.create(chatModel).prompt() + ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography of 5 movies for Tom Hanks.") .call() .entity(ActorsFilms.class); @@ -184,7 +182,7 @@ class AnthropicChatClientIT { BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilms.class); // @formatter:off - Flux chatResponse = ChatClient.create(chatModel) + Flux chatResponse = ChatClient.create(this.chatModel) .prompt() .advisors(new SimpleLoggerAdvisor()) .user(u -> u @@ -211,7 +209,7 @@ class AnthropicChatClientIT { void functionCallTest() { // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")) .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .call() @@ -227,7 +225,7 @@ class AnthropicChatClientIT { void defaultFunctionCallTest() { // @formatter:off - String response = ChatClient.builder(chatModel) + String response = ChatClient.builder(this.chatModel) .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")) .build() @@ -245,7 +243,7 @@ class AnthropicChatClientIT { void streamFunctionCallTest() { // @formatter:off - Flux response = ChatClient.create(chatModel).prompt() + Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .stream() @@ -264,7 +262,7 @@ class AnthropicChatClientIT { void multiModalityEmbeddedImage(String modelName) throws IOException { // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() .options(AnthropicChatOptions.builder().withModel(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?") .media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png"))) @@ -287,7 +285,7 @@ class AnthropicChatClientIT { URL url = new URL("https://docs.spring.io/spring-ai/reference/1.0.0-SNAPSHOT/_images/multimodal.test.png"); // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() // TODO consider adding model(...) method to ChatClient as a shortcut to .options(AnthropicChatOptions.builder().withModel(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url)) @@ -304,7 +302,7 @@ class AnthropicChatClientIT { void streamingMultiModality() throws IOException { // @formatter:off - Flux response = ChatClient.create(chatModel).prompt() + Flux response = ChatClient.create(this.chatModel).prompt() .options(AnthropicChatOptions.builder().withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET) .build()) .user(u -> u.text("Explain what do you see on this picture?") @@ -320,4 +318,8 @@ class AnthropicChatClientIT { assertThat(content).containsAnyOf("bowl", "basket"); } -} \ No newline at end of file + record ActorsFilms(String actor, List movies) { + + } + +} diff --git a/models/spring-ai-anthropic/src/test/resources/application-logging-test.properties b/models/spring-ai-anthropic/src/test/resources/application-logging-test.properties index 8e8b3b2c3..4466a7180 100644 --- a/models/spring-ai-anthropic/src/test/resources/application-logging-test.properties +++ b/models/spring-ai-anthropic/src/test/resources/application-logging-test.properties @@ -1 +1,17 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + logging.level.org.springframework.ai.chat.client.advisor=DEBUG diff --git a/models/spring-ai-azure-openai/pom.xml b/models/spring-ai-azure-openai/pom.xml index 634435628..101b9e508 100644 --- a/models/spring-ai-azure-openai/pom.xml +++ b/models/spring-ai-azure-openai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModel.java index 1d1e4afd9..314925b3a 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; +import java.io.IOException; +import java.util.List; + import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.AudioTranscriptionFormat; import com.azure.ai.openai.models.AudioTranscriptionOptions; import com.azure.ai.openai.models.AudioTranscriptionTimestampGranularity; import com.azure.core.http.rest.Response; + import org.springframework.ai.audio.transcription.AudioTranscription; import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; @@ -35,9 +40,6 @@ import org.springframework.core.io.Resource; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import java.io.IOException; -import java.util.List; - /** * AzureOpenAI audio transcription client implementation for backed by * {@link OpenAIClient}. You provide as input the audio file you want to transcribe and @@ -61,6 +63,15 @@ public class AzureOpenAiAudioTranscriptionModel implements Model words = null; @@ -108,7 +119,7 @@ public class AzureOpenAiAudioTranscriptionModel implements Model audioTranscription = openAIClient.getAudioTranscriptionTextWithResponse( + Response audioTranscription = this.openAIClient.getAudioTranscriptionTextWithResponse( deploymentOrModelName, FILENAME_MARKER, audioTranscriptionOptions, null); String text = audioTranscription.getValue(); AudioTranscription transcript = new AudioTranscription(text); @@ -119,7 +130,7 @@ public class AzureOpenAiAudioTranscriptionModel implements Model getGranularityType() { + return this.granularityType; + } + + public void setGranularityType(List granularityType) { + this.granularityType = granularityType; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.prompt == null) ? 0 : this.prompt.hashCode()); + result = prime * result + ((this.language == null) ? 0 : this.language.hashCode()); + result = prime * result + ((this.responseFormat == null) ? 0 : this.responseFormat.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + AzureOpenAiAudioTranscriptionOptions other = (AzureOpenAiAudioTranscriptionOptions) obj; + if (this.model == null) { + if (other.model != null) + return false; + } + else if (!this.model.equals(other.model)) + return false; + if (this.prompt == null) { + if (other.prompt != null) + return false; + } + else if (!this.prompt.equals(other.prompt)) + return false; + if (this.language == null) { + if (other.language != null) + return false; + } + else if (!this.language.equals(other.language)) + return false; + if (this.responseFormat == null) { + return other.responseFormat==null; + } + else return this.responseFormat.equals(other.responseFormat); + } + + public enum WhisperModel { + + // @formatter:off + @JsonProperty("whisper") WHISPER("whisper"); + // @formatter:on + + public final String value; + + WhisperModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + + public enum TranscriptResponseFormat { + + // @formatter:off + @JsonProperty("json") JSON(AudioTranscriptionFormat.JSON, StructuredResponse.class), + @JsonProperty("text") TEXT(AudioTranscriptionFormat.TEXT, String.class), + @JsonProperty("srt") SRT(AudioTranscriptionFormat.SRT, String.class), + @JsonProperty("verbose_json") VERBOSE_JSON(AudioTranscriptionFormat.VERBOSE_JSON, StructuredResponse.class), + @JsonProperty("vtt") VTT(AudioTranscriptionFormat.VTT, String.class); + + public final AudioTranscriptionFormat value; + + public final Class responseType; + + TranscriptResponseFormat(AudioTranscriptionFormat value, Class responseType) { + this.value = value; + this.responseType = responseType; + } + + public AudioTranscriptionFormat getValue() { + return this.value; + } + + public Class getResponseType() { + return this.responseType; + } + } + + public enum GranularityType { + + // @formatter:off + @JsonProperty("word") WORD(AudioTranscriptionTimestampGranularity.WORD), + @JsonProperty("segment") SEGMENT(AudioTranscriptionTimestampGranularity.SEGMENT); + // @formatter:on + + public final AudioTranscriptionTimestampGranularity value; + + GranularityType(AudioTranscriptionTimestampGranularity value) { + this.value = value; + } + + public AudioTranscriptionTimestampGranularity getValue() { + return this.value; + } + + } + public static class Builder { protected AzureOpenAiAudioTranscriptionOptions options; @@ -114,134 +281,14 @@ public class AzureOpenAiAudioTranscriptionOptions implements AudioTranscriptionO } public AzureOpenAiAudioTranscriptionOptions build() { - Assert.hasText(options.model, "model must not be empty"); - Assert.notNull(options.responseFormat, "response_format must not be null"); + Assert.hasText(this.options.model, "model must not be empty"); + Assert.notNull(this.options.responseFormat, "response_format must not be null"); return this.options; } } - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - public String getDeploymentName() { - return deploymentName; - } - - public void setDeploymentName(String deploymentName) { - this.deploymentName = deploymentName; - } - - public String getLanguage() { - return this.language; - } - - public void setLanguage(String language) { - this.language = language; - } - - public String getPrompt() { - return this.prompt; - } - - public void setPrompt(String prompt) { - this.prompt = prompt; - } - - public Float getTemperature() { - return this.temperature; - } - - public void setTemperature(Float temperature) { - this.temperature = temperature; - } - - - public TranscriptResponseFormat getResponseFormat() { - return this.responseFormat; - } - - public void setResponseFormat(TranscriptResponseFormat responseFormat) { - this.responseFormat = responseFormat; - } - - public List getGranularityType() { - return this.granularityType; - } - - public void setGranularityType(List granularityType) { - this.granularityType = granularityType; - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((prompt == null) ? 0 : prompt.hashCode()); - result = prime * result + ((language == null) ? 0 : language.hashCode()); - result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - AzureOpenAiAudioTranscriptionOptions other = (AzureOpenAiAudioTranscriptionOptions) obj; - if (this.model == null) { - if (other.model != null) - return false; - } - else if (!model.equals(other.model)) - return false; - if (this.prompt == null) { - if (other.prompt != null) - return false; - } - else if (!this.prompt.equals(other.prompt)) - return false; - if (this.language == null) { - if (other.language != null) - return false; - } - else if (!this.language.equals(other.language)) - return false; - if (this.responseFormat == null) { - return other.responseFormat==null; - } - else return this.responseFormat.equals(other.responseFormat); - } - - public enum WhisperModel { - - // @formatter:off - @JsonProperty("whisper") WHISPER("whisper"); - // @formatter:on - - public final String value; - - WhisperModel(String value) { - this.value = value; - } - - public String getValue() { - return value; - } - - } - /** * @param language The language of the transcribed text. * @param duration The duration of the audio in seconds. @@ -308,51 +355,6 @@ public class AzureOpenAiAudioTranscriptionOptions implements AudioTranscriptionO @JsonProperty("no_speech_prob") Float noSpeechProb) { // @formatter:on } - } - - public enum TranscriptResponseFormat { - - // @formatter:off - @JsonProperty("json") JSON(AudioTranscriptionFormat.JSON, StructuredResponse.class), - @JsonProperty("text") TEXT(AudioTranscriptionFormat.TEXT, String.class), - @JsonProperty("srt") SRT(AudioTranscriptionFormat.SRT, String.class), - @JsonProperty("verbose_json") VERBOSE_JSON(AudioTranscriptionFormat.VERBOSE_JSON, StructuredResponse.class), - @JsonProperty("vtt") VTT(AudioTranscriptionFormat.VTT, String.class); - - public final AudioTranscriptionFormat value; - - public final Class responseType; - - TranscriptResponseFormat(AudioTranscriptionFormat value, Class responseType) { - this.value = value; - this.responseType = responseType; - } - - public AudioTranscriptionFormat getValue() { - return this.value; - } - - public Class getResponseType() { - return this.responseType; - } - } - - public enum GranularityType { - - // @formatter:off - @JsonProperty("word") WORD(AudioTranscriptionTimestampGranularity.WORD), - @JsonProperty("segment") SEGMENT(AudioTranscriptionTimestampGranularity.SEGMENT); - // @formatter:on - - public final AudioTranscriptionTimestampGranularity value; - - GranularityType(AudioTranscriptionTimestampGranularity value) { - this.value = value; - } - - public AudioTranscriptionTimestampGranularity getValue() { - return this.value; - } } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 2436c7d3d..77c2ea0b2 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,13 +16,49 @@ package org.springframework.ai.azure.openai; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; -import com.azure.ai.openai.models.*; +import com.azure.ai.openai.models.ChatChoice; +import com.azure.ai.openai.models.ChatCompletions; +import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall; +import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition; +import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; +import com.azure.ai.openai.models.ChatCompletionsOptions; +import com.azure.ai.openai.models.ChatCompletionsResponseFormat; +import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat; +import com.azure.ai.openai.models.ChatCompletionsToolCall; +import com.azure.ai.openai.models.ChatCompletionsToolDefinition; +import com.azure.ai.openai.models.ChatMessageContentItem; +import com.azure.ai.openai.models.ChatMessageImageContentItem; +import com.azure.ai.openai.models.ChatMessageImageUrl; +import com.azure.ai.openai.models.ChatMessageTextContentItem; +import com.azure.ai.openai.models.ChatRequestAssistantMessage; +import com.azure.ai.openai.models.ChatRequestMessage; +import com.azure.ai.openai.models.ChatRequestSystemMessage; +import com.azure.ai.openai.models.ChatRequestToolMessage; +import com.azure.ai.openai.models.ChatRequestUserMessage; +import com.azure.ai.openai.models.CompletionsFinishReason; +import com.azure.ai.openai.models.ContentFilterResultsForPrompt; +import com.azure.ai.openai.models.FunctionCall; +import com.azure.ai.openai.models.FunctionDefinition; import com.azure.core.util.BinaryData; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import reactor.core.publisher.Flux; + import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -54,20 +90,6 @@ import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; -import reactor.core.publisher.Flux; - -import java.util.ArrayList; -import java.util.Base64; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicBoolean; - /** * {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by * {@link OpenAIClient}. @@ -153,6 +175,19 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha this.observationRegistry = observationRegistry; } + public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) { + Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null"); + String id = chatCompletions.getId(); + Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage(); + return ChatResponseMetadata.builder() + .withId(id) + .withUsage(usage) + .withModel(chatCompletions.getModel()) + .withPromptMetadata(promptFilterMetadata) + .withKeyValue("system-fingerprint", chatCompletions.getSystemFingerprint()) + .build(); + } + public AzureOpenAiChatOptions getDefaultOptions() { return AzureOpenAiChatOptions.fromOptions(this.defaultOptions); } @@ -302,19 +337,6 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha return new Generation(assistantMessage, generationMetadata); } - public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) { - Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null"); - String id = chatCompletions.getId(); - Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage(); - return ChatResponseMetadata.builder() - .withId(id) - .withUsage(usage) - .withModel(chatCompletions.getModel()) - .withPromptMetadata(promptFilterMetadata) - .withKeyValue("system-fingerprint", chatCompletions.getSystemFingerprint()) - .build(); - } - /** * Test access. */ @@ -332,8 +354,9 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha options = this.merge(options, this.defaultOptions); - if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) + if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) { functionsForThisRequest.addAll(this.defaultOptions.getFunctions()); + } if (prompt.getOptions() != null) { AzureOpenAiChatOptions updatedRuntimeOptions; @@ -428,14 +451,16 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha private String getMediaUrl(Media media) { Object data = media.getData(); - if (data instanceof String dataUrl) + if (data instanceof String dataUrl) { return dataUrl; + } else if (data instanceof byte[] dataBytes) { String base64EncodedData = Base64.getEncoder().encodeToString(dataBytes); return "data:" + media.getMimeType() + ";base64," + base64EncodedData; } - else + else { throw new IllegalArgumentException("Unknown media data type " + data.getClass().getName()); + } } private ChatGenerationMetadata generateChoiceMetadata(ChatChoice choice) { diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index 5685fa43e..f890f1266 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -22,18 +22,18 @@ import java.util.List; import java.util.Map; import java.util.Set; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.boot.context.properties.NestedConfigurationProperty; -import org.springframework.util.Assert; - import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; + /** * The configuration information for a chat completions request. Completions support a * wide variety of tasks and generate text that continues from or "completes" provided @@ -206,129 +206,26 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio return new Builder(); } - public static class Builder { - - protected AzureOpenAiChatOptions options; - - public Builder() { - this.options = new AzureOpenAiChatOptions(); - } - - public Builder(AzureOpenAiChatOptions options) { - this.options = options; - } - - public Builder withDeploymentName(String deploymentName) { - this.options.deploymentName = deploymentName; - return this; - } - - public Builder withFrequencyPenalty(Double frequencyPenalty) { - this.options.frequencyPenalty = frequencyPenalty; - return this; - } - - public Builder withLogitBias(Map logitBias) { - this.options.logitBias = logitBias; - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.maxTokens = maxTokens; - return this; - } - - public Builder withN(Integer n) { - this.options.n = n; - return this; - } - - public Builder withPresencePenalty(Double presencePenalty) { - this.options.presencePenalty = presencePenalty; - return this; - } - - public Builder withStop(List stop) { - this.options.stop = stop; - return this; - } - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public Builder withUser(String user) { - this.options.user = user; - return this; - } - - public Builder withFunctionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; - return this; - } - - public Builder withFunctions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; - return this; - } - - public Builder withFunction(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); - return this; - } - - public Builder withResponseFormat(AzureOpenAiResponseFormat responseFormat) { - this.options.responseFormat = responseFormat; - return this; - } - - public Builder withProxyToolCalls(Boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; - return this; - } - - public Builder withSeed(Long seed) { - this.options.seed = seed; - return this; - } - - public Builder withLogprobs(Boolean logprobs) { - this.options.logprobs = logprobs; - return this; - } - - public Builder withTopLogprobs(Integer topLogprobs) { - this.options.topLogProbs = topLogprobs; - return this; - } - - public Builder withEnhancements(AzureChatEnhancementConfiguration enhancements) { - this.options.enhancements = enhancements; - return this; - } - - public Builder withToolContext(Map toolContext) { - if (this.options.toolContext == null) { - this.options.toolContext = toolContext; - } - else { - this.options.toolContext.putAll(toolContext); - } - return this; - } - - public AzureOpenAiChatOptions build() { - return this.options; - } - + public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOptions) { + return builder().withDeploymentName(fromOptions.getDeploymentName()) + .withFrequencyPenalty(fromOptions.getFrequencyPenalty() != null ? fromOptions.getFrequencyPenalty() : null) + .withLogitBias(fromOptions.getLogitBias()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withN(fromOptions.getN()) + .withPresencePenalty(fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty() : null) + .withStop(fromOptions.getStop()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withUser(fromOptions.getUser()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) + .withResponseFormat(fromOptions.getResponseFormat()) + .withSeed(fromOptions.getSeed()) + .withLogprobs(fromOptions.isLogprobs()) + .withTopLogprobs(fromOptions.getTopLogProbs()) + .withEnhancements(fromOptions.getEnhancements()) + .withToolContext(fromOptions.getToolContext()) + .build(); } @Override @@ -526,26 +423,129 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio return fromOptions(this); } - public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOptions) { - return builder().withDeploymentName(fromOptions.getDeploymentName()) - .withFrequencyPenalty(fromOptions.getFrequencyPenalty() != null ? fromOptions.getFrequencyPenalty() : null) - .withLogitBias(fromOptions.getLogitBias()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withN(fromOptions.getN()) - .withPresencePenalty(fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty() : null) - .withStop(fromOptions.getStop()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withUser(fromOptions.getUser()) - .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) - .withFunctions(fromOptions.getFunctions()) - .withResponseFormat(fromOptions.getResponseFormat()) - .withSeed(fromOptions.getSeed()) - .withLogprobs(fromOptions.isLogprobs()) - .withTopLogprobs(fromOptions.getTopLogProbs()) - .withEnhancements(fromOptions.getEnhancements()) - .withToolContext(fromOptions.getToolContext()) - .build(); + public static class Builder { + + protected AzureOpenAiChatOptions options; + + public Builder() { + this.options = new AzureOpenAiChatOptions(); + } + + public Builder(AzureOpenAiChatOptions options) { + this.options = options; + } + + public Builder withDeploymentName(String deploymentName) { + this.options.deploymentName = deploymentName; + return this; + } + + public Builder withFrequencyPenalty(Double frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withLogitBias(Map logitBias) { + this.options.logitBias = logitBias; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withN(Integer n) { + this.options.n = n; + return this; + } + + public Builder withPresencePenalty(Double presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder withUser(String user) { + this.options.user = user; + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public Builder withResponseFormat(AzureOpenAiResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + + public Builder withSeed(Long seed) { + this.options.seed = seed; + return this; + } + + public Builder withLogprobs(Boolean logprobs) { + this.options.logprobs = logprobs; + return this; + } + + public Builder withTopLogprobs(Integer topLogprobs) { + this.options.topLogProbs = topLogprobs; + return this; + } + + public Builder withEnhancements(AzureChatEnhancementConfiguration enhancements) { + this.options.enhancements = enhancements; + return this; + } + + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public AzureOpenAiChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java index 178275851..c7ca01b22 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; +import java.util.ArrayList; +import java.util.List; + import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.EmbeddingItem; import com.azure.ai.openai.models.Embeddings; import com.azure.ai.openai.models.EmbeddingsOptions; - import io.micrometer.observation.ObservationRegistry; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.azure.openai.metadata.AzureOpenAiEmbeddingUsage; import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; @@ -41,9 +44,6 @@ import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; -import java.util.List; - /** * Azure Open AI Embedding Model implementation. * @@ -56,14 +56,14 @@ public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel { private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiEmbeddingModel.class); + private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + private final OpenAIClient azureOpenAiClient; private final AzureOpenAiEmbeddingOptions defaultOptions; private final MetadataMode metadataMode; - private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); - /** * Observation registry used for instrumentation. */ diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java index 7713f95f6..e2e8f3e24 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import java.util.List; import com.fasterxml.jackson.annotation.JsonIgnore; + import org.springframework.ai.embedding.EmbeddingOptions; /** @@ -58,6 +60,61 @@ public class AzureOpenAiEmbeddingOptions implements EmbeddingOptions { return new Builder(); } + @Override + @JsonIgnore + public String getModel() { + return getDeploymentName(); + } + + @JsonIgnore + public void setModel(String model) { + setDeploymentName(model); + } + + public String getUser() { + return this.user; + } + + public void setUser(String user) { + this.user = user; + } + + public String getDeploymentName() { + return this.deploymentName; + } + + public void setDeploymentName(String deploymentName) { + this.deploymentName = deploymentName; + } + + public String getInputType() { + return this.inputType; + } + + public void setInputType(String inputType) { + this.inputType = inputType; + } + + @Override + public Integer getDimensions() { + return this.dimensions; + } + + public void setDimensions(Integer dimensions) { + this.dimensions = dimensions; + } + + public com.azure.ai.openai.models.EmbeddingsOptions toAzureOptions(List instructions) { + + var azureOptions = new com.azure.ai.openai.models.EmbeddingsOptions(instructions); + azureOptions.setModel(this.getDeploymentName()); + azureOptions.setUser(this.getUser()); + azureOptions.setInputType(this.getInputType()); + azureOptions.setDimensions(this.getDimensions()); + + return azureOptions; + } + public static class Builder { private final AzureOpenAiEmbeddingOptions options = new AzureOpenAiEmbeddingOptions(); @@ -125,59 +182,4 @@ public class AzureOpenAiEmbeddingOptions implements EmbeddingOptions { } - @Override - @JsonIgnore - public String getModel() { - return getDeploymentName(); - } - - @JsonIgnore - public void setModel(String model) { - setDeploymentName(model); - } - - public String getUser() { - return this.user; - } - - public void setUser(String user) { - this.user = user; - } - - public String getDeploymentName() { - return this.deploymentName; - } - - public void setDeploymentName(String deploymentName) { - this.deploymentName = deploymentName; - } - - public String getInputType() { - return this.inputType; - } - - public void setInputType(String inputType) { - this.inputType = inputType; - } - - @Override - public Integer getDimensions() { - return this.dimensions; - } - - public void setDimensions(Integer dimensions) { - this.dimensions = dimensions; - } - - public com.azure.ai.openai.models.EmbeddingsOptions toAzureOptions(List instructions) { - - var azureOptions = new com.azure.ai.openai.models.EmbeddingsOptions(instructions); - azureOptions.setModel(this.getDeploymentName()); - azureOptions.setUser(this.getUser()); - azureOptions.setInputType(this.getInputType()); - azureOptions.setDimensions(this.getDimensions()); - - return azureOptions; - } - } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java index 7d2b3dae3..9b1b466ef 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java @@ -1,5 +1,23 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.azure.openai; +import java.util.List; + import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.ImageGenerationOptions; import com.azure.ai.openai.models.ImageGenerationQuality; @@ -13,6 +31,7 @@ import com.fasterxml.jackson.databind.SerializationFeature; import com.fasterxml.jackson.databind.json.JsonMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.azure.openai.metadata.AzureOpenAiImageGenerationMetadata; import org.springframework.ai.azure.openai.metadata.AzureOpenAiImageResponseMetadata; import org.springframework.ai.image.Image; @@ -25,8 +44,6 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.util.JacksonUtils; import org.springframework.util.Assert; -import java.util.List; - import static java.lang.String.format; /** @@ -68,22 +85,22 @@ public class AzureOpenAiImageModel implements ImageModel { } public AzureOpenAiImageOptions getDefaultOptions() { - return defaultOptions; + return this.defaultOptions; } @Override public ImageResponse call(ImagePrompt imagePrompt) { ImageGenerationOptions imageGenerationOptions = toOpenAiImageOptions(imagePrompt); String deploymentOrModelName = getDeploymentName(imagePrompt); - if (logger.isTraceEnabled()) { - logger.trace("Azure ImageGenerationOptions call {} with the following options : {} ", deploymentOrModelName, - toPrettyJson(imageGenerationOptions)); + if (this.logger.isTraceEnabled()) { + this.logger.trace("Azure ImageGenerationOptions call {} with the following options : {} ", + deploymentOrModelName, toPrettyJson(imageGenerationOptions)); } - var images = openAIClient.getImageGenerations(deploymentOrModelName, imageGenerationOptions); + var images = this.openAIClient.getImageGenerations(deploymentOrModelName, imageGenerationOptions); - if (logger.isTraceEnabled()) { - logger.trace("Azure ImageGenerations: {}", toPrettyJson(images)); + if (this.logger.isTraceEnabled()) { + this.logger.trace("Azure ImageGenerations: {}", toPrettyJson(images)); } List imageGenerations = images.getData().stream().map(entry -> { diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java index be15fbfd1..2e6d13c57 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java @@ -1,12 +1,28 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.azure.openai; import java.util.Objects; import com.fasterxml.jackson.annotation.JsonInclude; -import org.springframework.ai.image.ImageOptions; - import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.image.ImageOptions; + /** * The configuration information for a image generation request. * @@ -89,9 +105,13 @@ public class AzureOpenAiImageOptions implements ImageOptions { @JsonProperty("user") private String user; + public static Builder builder() { + return new Builder(); + } + @Override public Integer getN() { - return n; + return this.n; } public void setN(Integer n) { @@ -100,7 +120,7 @@ public class AzureOpenAiImageOptions implements ImageOptions { @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -109,7 +129,7 @@ public class AzureOpenAiImageOptions implements ImageOptions { @Override public Integer getWidth() { - return width; + return this.width; } public void setWidth(Integer width) { @@ -119,7 +139,7 @@ public class AzureOpenAiImageOptions implements ImageOptions { @Override public Integer getHeight() { - return height; + return this.height; } public void setHeight(Integer height) { @@ -129,7 +149,7 @@ public class AzureOpenAiImageOptions implements ImageOptions { @Override public String getResponseFormat() { - return responseFormat; + return this.responseFormat; } public void setResponseFormat(String responseFormat) { @@ -148,7 +168,7 @@ public class AzureOpenAiImageOptions implements ImageOptions { } public String getUser() { - return user; + return this.user; } public void setUser(String user) { @@ -156,7 +176,7 @@ public class AzureOpenAiImageOptions implements ImageOptions { } public String getQuality() { - return quality; + return this.quality; } public void setQuality(String quality) { @@ -165,7 +185,7 @@ public class AzureOpenAiImageOptions implements ImageOptions { @Override public String getStyle() { - return style; + return this.style; } public void setStyle(String style) { @@ -173,95 +193,40 @@ public class AzureOpenAiImageOptions implements ImageOptions { } public String getDeploymentName() { - return deploymentName; + return this.deploymentName; } public void setDeploymentName(String deploymentName) { this.deploymentName = deploymentName; } - public static Builder builder() { - return new Builder(); - } - @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof AzureOpenAiImageOptions that)) + } + if (!(o instanceof AzureOpenAiImageOptions that)) { return false; - return Objects.equals(n, that.n) && Objects.equals(model, that.model) - && Objects.equals(deploymentName, that.deploymentName) && Objects.equals(width, that.width) - && Objects.equals(height, that.height) && Objects.equals(quality, that.quality) - && Objects.equals(responseFormat, that.responseFormat) && Objects.equals(size, that.size) - && Objects.equals(style, that.style) && Objects.equals(user, that.user); + } + return Objects.equals(this.n, that.n) && Objects.equals(this.model, that.model) + && Objects.equals(this.deploymentName, that.deploymentName) && Objects.equals(this.width, that.width) + && Objects.equals(this.height, that.height) && Objects.equals(this.quality, that.quality) + && Objects.equals(this.responseFormat, that.responseFormat) && Objects.equals(this.size, that.size) + && Objects.equals(this.style, that.style) && Objects.equals(this.user, that.user); } @Override public int hashCode() { - return Objects.hash(n, model, deploymentName, width, height, quality, responseFormat, size, style, user); + return Objects.hash(this.n, this.model, this.deploymentName, this.width, this.height, this.quality, + this.responseFormat, this.size, this.style, this.user); } @Override public String toString() { - return "AzureOpenAiImageOptions{" + "n=" + n + ", model='" + model + '\'' + ", deploymentName='" - + deploymentName + '\'' + ", width=" + width + ", height=" + height + ", quality='" + quality + '\'' - + ", responseFormat='" + responseFormat + '\'' + ", size='" + size + '\'' + ", style='" + style + '\'' - + ", user='" + user + '\'' + '}'; - } - - public static class Builder { - - private final AzureOpenAiImageOptions options; - - private Builder() { - this.options = new AzureOpenAiImageOptions(); - } - - public Builder withN(Integer n) { - options.setN(n); - return this; - } - - public Builder withModel(String model) { - options.setModel(model); - return this; - } - - public Builder withDeploymentName(String deploymentName) { - options.setDeploymentName(deploymentName); - return this; - } - - public Builder withResponseFormat(String responseFormat) { - options.setResponseFormat(responseFormat); - return this; - } - - public Builder withWidth(Integer width) { - options.setWidth(width); - return this; - } - - public Builder withHeight(Integer height) { - options.setHeight(height); - return this; - } - - public Builder withUser(String user) { - options.setUser(user); - return this; - } - - public AzureOpenAiImageOptions build() { - return options; - } - - public Builder withStyle(String style) { - options.setStyle(style); - return this; - } - + return "AzureOpenAiImageOptions{" + "n=" + this.n + ", model='" + this.model + '\'' + ", deploymentName='" + + this.deploymentName + '\'' + ", width=" + this.width + ", height=" + this.height + ", quality='" + + this.quality + '\'' + ", responseFormat='" + this.responseFormat + '\'' + ", size='" + this.size + + '\'' + ", style='" + this.style + '\'' + ", user='" + this.user + '\'' + '}'; } public enum ImageModel { @@ -290,4 +255,58 @@ public class AzureOpenAiImageOptions implements ImageOptions { } + public static class Builder { + + private final AzureOpenAiImageOptions options; + + private Builder() { + this.options = new AzureOpenAiImageOptions(); + } + + public Builder withN(Integer n) { + this.options.setN(n); + return this; + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withDeploymentName(String deploymentName) { + this.options.setDeploymentName(deploymentName); + return this; + } + + public Builder withResponseFormat(String responseFormat) { + this.options.setResponseFormat(responseFormat); + return this; + } + + public Builder withWidth(Integer width) { + this.options.setWidth(width); + return this; + } + + public Builder withHeight(Integer height) { + this.options.setHeight(height); + return this; + } + + public Builder withUser(String user) { + this.options.setUser(user); + return this; + } + + public AzureOpenAiImageOptions build() { + return this.options; + } + + public Builder withStyle(String style) { + this.options.setStyle(style); + return this; + } + + } + } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiResponseFormat.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiResponseFormat.java index 31bcb7458..fd83532ec 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiResponseFormat.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiResponseFormat.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; /** diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java index 141181768..82c1f57b5 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import java.lang.reflect.Constructor; @@ -49,6 +50,15 @@ import org.springframework.util.CollectionUtils; */ public class MergeUtils { + private static final Class[] CHAT_COMPLETIONS_CONSTRUCTOR_ARG_TYPES = new Class[] { String.class, + OffsetDateTime.class, List.class, CompletionsUsage.class }; + + private static final Class[] chatChoiceConstructorArgumentTypes = new Class[] { + ChatChoiceLogProbabilityInfo.class, int.class, CompletionsFinishReason.class }; + + private static final Class[] chatResponseMessageConstructorArgumentTypes = new Class[] { ChatRole.class, + String.class }; + /** * Create a new instance of the given class using the constructor at the given index. * Can be used to create instances with private constructors. @@ -106,9 +116,6 @@ public class MergeUtils { return chatCompletionsInstance; } - private static final Class[] CHAT_COMPLETIONS_CONSTRUCTOR_ARG_TYPES = new Class[] { String.class, - OffsetDateTime.class, List.class, CompletionsUsage.class }; - /** * Merge two ChatCompletions instances into a single ChatCompletions instance. * @param left the left ChatCompletions instance. @@ -158,9 +165,6 @@ public class MergeUtils { return instance; } - private static final Class[] chatChoiceConstructorArgumentTypes = new Class[] { - ChatChoiceLogProbabilityInfo.class, int.class, CompletionsFinishReason.class }; - /** * Merge two ChatChoice instances into a single ChatChoice instance. * @param left the left ChatChoice instance to merge. @@ -211,9 +215,6 @@ public class MergeUtils { return instance; } - private static final Class[] chatResponseMessageConstructorArgumentTypes = new Class[] { ChatRole.class, - String.class }; - /** * Merge two ChatResponseMessage instances into a single ChatResponseMessage instance. * @param left the left ChatResponseMessage instance to merge. diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHints.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHints.java index 488870bcb..75ba720b0 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHints.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai.aot; import com.azure.ai.openai.OpenAIAsyncClient; diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiAudioTranscriptionResponseMetadata.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiAudioTranscriptionResponseMetadata.java index f64a805a1..a55ecd604 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiAudioTranscriptionResponseMetadata.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiAudioTranscriptionResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai.metadata; import org.springframework.ai.audio.transcription.AudioTranscriptionResponseMetadata; @@ -26,10 +27,14 @@ import org.springframework.util.Assert; */ public class AzureOpenAiAudioTranscriptionResponseMetadata extends AudioTranscriptionResponseMetadata { + public static final AzureOpenAiAudioTranscriptionResponseMetadata NULL = new AzureOpenAiAudioTranscriptionResponseMetadata() { + + }; + protected static final String AI_METADATA_STRING = "{ @type: %1$s }"; - public static final AzureOpenAiAudioTranscriptionResponseMetadata NULL = new AzureOpenAiAudioTranscriptionResponseMetadata() { - }; + protected AzureOpenAiAudioTranscriptionResponseMetadata() { + } public static AzureOpenAiAudioTranscriptionResponseMetadata from( AzureOpenAiAudioTranscriptionOptions.StructuredResponse result) { @@ -42,9 +47,6 @@ public class AzureOpenAiAudioTranscriptionResponseMetadata extends AudioTranscri return new AzureOpenAiAudioTranscriptionResponseMetadata(); } - protected AzureOpenAiAudioTranscriptionResponseMetadata() { - } - @Override public String toString() { return AI_METADATA_STRING.formatted(getClass().getName()); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiEmbeddingUsage.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiEmbeddingUsage.java index 8ec132871..8fe0fa1e4 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiEmbeddingUsage.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiEmbeddingUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai.metadata; import com.azure.ai.openai.models.EmbeddingsUsage; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.util.Assert; @@ -27,11 +29,6 @@ import org.springframework.util.Assert; */ public class AzureOpenAiEmbeddingUsage implements Usage { - public static AzureOpenAiEmbeddingUsage from(EmbeddingsUsage usage) { - Assert.notNull(usage, "EmbeddingsUsage must not be null"); - return new AzureOpenAiEmbeddingUsage(usage); - } - private final EmbeddingsUsage usage; public AzureOpenAiEmbeddingUsage(EmbeddingsUsage usage) { @@ -39,6 +36,11 @@ public class AzureOpenAiEmbeddingUsage implements Usage { this.usage = usage; } + public static AzureOpenAiEmbeddingUsage from(EmbeddingsUsage usage) { + Assert.notNull(usage, "EmbeddingsUsage must not be null"); + return new AzureOpenAiEmbeddingUsage(usage); + } + protected EmbeddingsUsage getUsage() { return this.usage; } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageGenerationMetadata.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageGenerationMetadata.java index 44b429e9f..eecc94ef7 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageGenerationMetadata.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageGenerationMetadata.java @@ -1,9 +1,25 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.azure.openai.metadata; -import org.springframework.ai.image.ImageGenerationMetadata; - import java.util.Objects; +import org.springframework.ai.image.ImageGenerationMetadata; + /** * Represents the metadata for image generation using Azure OpenAI. * @@ -19,25 +35,27 @@ public class AzureOpenAiImageGenerationMetadata implements ImageGenerationMetada } public String getRevisedPrompt() { - return revisedPrompt; + return this.revisedPrompt; } public String toString() { - return "AzureOpenAiImageGenerationMetadata{" + "revisedPrompt='" + revisedPrompt + '\'' + '}'; + return "AzureOpenAiImageGenerationMetadata{" + "revisedPrompt='" + this.revisedPrompt + '\'' + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof AzureOpenAiImageGenerationMetadata that)) + } + if (!(o instanceof AzureOpenAiImageGenerationMetadata that)) { return false; - return Objects.equals(revisedPrompt, that.revisedPrompt); + } + return Objects.equals(this.revisedPrompt, that.revisedPrompt); } @Override public int hashCode() { - return Objects.hash(revisedPrompt); + return Objects.hash(this.revisedPrompt); } } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageResponseMetadata.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageResponseMetadata.java index 6d01d5cbb..cdc24d0ab 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageResponseMetadata.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiImageResponseMetadata.java @@ -1,13 +1,28 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.azure.openai.metadata; -import com.azure.ai.openai.models.ImageGenerations; -import org.springframework.ai.image.ImageResponseMetadata; -import org.springframework.ai.model.MutableResponseMetadata; -import org.springframework.util.Assert; - -import java.util.HashMap; import java.util.Objects; +import com.azure.ai.openai.models.ImageGenerations; + +import org.springframework.ai.image.ImageResponseMetadata; +import org.springframework.util.Assert; + /** * Represents metadata associated with an image response from the Azure OpenAI image * model. It provides additional information about the generative response from the Azure @@ -20,15 +35,15 @@ public class AzureOpenAiImageResponseMetadata extends ImageResponseMetadata { private final Long created; + protected AzureOpenAiImageResponseMetadata(Long created) { + this.created = created; + } + public static AzureOpenAiImageResponseMetadata from(ImageGenerations openAiImageResponse) { Assert.notNull(openAiImageResponse, "OpenAiImageResponse must not be null"); return new AzureOpenAiImageResponseMetadata(openAiImageResponse.getCreatedAt().toEpochSecond()); } - protected AzureOpenAiImageResponseMetadata(Long created) { - this.created = created; - } - @Override public Long getCreated() { return this.created; @@ -36,21 +51,23 @@ public class AzureOpenAiImageResponseMetadata extends ImageResponseMetadata { @Override public String toString() { - return "AzureOpenAiImageResponseMetadata{" + "created=" + created + '}'; + return "AzureOpenAiImageResponseMetadata{" + "created=" + this.created + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof AzureOpenAiImageResponseMetadata that)) + } + if (!(o instanceof AzureOpenAiImageResponseMetadata that)) { return false; - return Objects.equals(created, that.created); + } + return Objects.equals(this.created, that.created); } @Override public int hashCode() { - return Objects.hash(created); + return Objects.hash(this.created); } } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java index 056d44eb0..b0dd15d13 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai.metadata; import com.azure.ai.openai.models.ChatCompletions; @@ -30,6 +31,13 @@ import org.springframework.util.Assert; */ public class AzureOpenAiUsage implements Usage { + private final CompletionsUsage usage; + + public AzureOpenAiUsage(CompletionsUsage usage) { + Assert.notNull(usage, "CompletionsUsage must not be null"); + this.usage = usage; + } + public static AzureOpenAiUsage from(ChatCompletions chatCompletions) { Assert.notNull(chatCompletions, "ChatCompletions must not be null"); return from(chatCompletions.getUsage()); @@ -39,13 +47,6 @@ public class AzureOpenAiUsage implements Usage { return new AzureOpenAiUsage(usage); } - private final CompletionsUsage usage; - - public AzureOpenAiUsage(CompletionsUsage usage) { - Assert.notNull(usage, "CompletionsUsage must not be null"); - this.usage = usage; - } - protected CompletionsUsage getUsage() { return this.usage; } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java index dbc6fa46d..e686d17fa 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,10 +16,12 @@ package org.springframework.ai.azure.openai; -import com.azure.ai.openai.OpenAIClient; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; -import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration; import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat; import org.junit.jupiter.api.Test; @@ -30,10 +32,6 @@ import org.mockito.Mockito; import org.springframework.ai.chat.prompt.Prompt; -import java.util.List; -import java.util.Map; -import java.util.stream.Stream; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -42,6 +40,11 @@ import static org.assertj.core.api.Assertions.assertThat; */ public class AzureChatCompletionsOptionsTests { + private static Stream providePresencePenaltyAndFrequencyPenaltyTest() { + return Stream.of(Arguments.of(0.0, 0.0), Arguments.of(0.0, 1.0), Arguments.of(1.0, 0.0), Arguments.of(1.0, 1.0), + Arguments.of(1.0, null), Arguments.of(null, 1.0), Arguments.of(null, null)); + } + @Test public void createRequestWithChatOptions() { @@ -132,11 +135,6 @@ public class AzureChatCompletionsOptionsTests { assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsJsonResponseFormat.class); } - private static Stream providePresencePenaltyAndFrequencyPenaltyTest() { - return Stream.of(Arguments.of(0.0, 0.0), Arguments.of(0.0, 1.0), Arguments.of(1.0, 0.0), Arguments.of(1.0, 1.0), - Arguments.of(1.0, null), Arguments.of(null, 1.0), Arguments.of(null, null)); - } - @ParameterizedTest @MethodSource("providePresencePenaltyAndFrequencyPenaltyTest") public void createChatOptionsWithPresencePenaltyAndFrequencyPenalty(Double presencePenalty, diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java index 18fe0e56a..627824385 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import java.util.List; diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModelIT.java index a8a7d44ae..e3fbcc92a 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModelIT.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.azure.openai; import com.azure.ai.openai.OpenAIClient; @@ -6,6 +22,7 @@ import com.azure.ai.openai.OpenAIServiceVersion; import com.azure.core.credential.AzureKeyCredential; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; import org.springframework.beans.factory.annotation.Autowired; @@ -38,8 +55,9 @@ class AzureOpenAiAudioTranscriptionModelIT { .withResponseFormat(AzureOpenAiAudioTranscriptionOptions.TranscriptResponseFormat.TEXT) .withTemperature(0f) .build(); - AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); - AudioTranscriptionResponse response = transcriptionModel.call(transcriptionRequest); + AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(this.audioFile, + transcriptionOptions); + AudioTranscriptionResponse response = this.transcriptionModel.call(transcriptionRequest); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue(); } @@ -54,8 +72,9 @@ class AzureOpenAiAudioTranscriptionModelIT { .withTemperature(0f) .withResponseFormat(responseFormat) .build(); - AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); - AudioTranscriptionResponse response = transcriptionModel.call(transcriptionRequest); + AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(this.audioFile, + transcriptionOptions); + AudioTranscriptionResponse response = this.transcriptionModel.call(transcriptionRequest); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue(); } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java index 4b931b799..8babb2e0c 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,17 @@ package org.springframework.ai.azure.openai; -import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.OpenAIServiceVersion; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.core.http.policy.HttpLogOptions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; @@ -35,13 +37,10 @@ import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; - -import com.azure.ai.openai.OpenAIClientBuilder; -import com.azure.ai.openai.OpenAIServiceVersion; -import com.azure.core.credential.AzureKeyCredential; -import com.azure.core.http.policy.HttpLogOptions; import org.springframework.core.io.Resource; -import reactor.core.publisher.Flux; + +import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Soby Chacko @@ -57,16 +56,13 @@ public class AzureOpenAiChatClientIT { @Value("classpath:/prompts/system-message.st") private Resource systemTextResource; - record ActorsFilms(String actor, List movies) { - } - @Test void call() { // @formatter:off - ChatResponse response = chatClient.prompt() + ChatResponse response = this.chatClient.prompt() .advisors(new SimpleLoggerAdvisor()) - .system(s -> s.text(systemTextResource) + .system(s -> s.text(this.systemTextResource) .param("name", "Bob") .param("voice", "pirate")) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") @@ -84,7 +80,7 @@ public class AzureOpenAiChatClientIT { BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilms.class); // @formatter:off - Flux chatResponse = chatClient + Flux chatResponse = this.chatClient .prompt() .advisors(new SimpleLoggerAdvisor()) .user(u -> u @@ -117,12 +113,12 @@ public class AzureOpenAiChatClientIT { + "List them with a numerical index. Do not use any abbreviations in state or capitals."; // Imperative call - String rawDataFromImperativeCall = chatClient.prompt(prompt).call().content(); + String rawDataFromImperativeCall = this.chatClient.prompt(prompt).call().content(); String imperativeStatesData = extractStatesData(rawDataFromImperativeCall); String formattedImperativeResponse = formatResponse(imperativeStatesData); // Streaming call - String stitchedResponseFromStream = chatClient.prompt(prompt) + String stitchedResponseFromStream = this.chatClient.prompt(prompt) .stream() .content() .collectList() @@ -150,6 +146,10 @@ public class AzureOpenAiChatClientIT { return String.join("\n", Arrays.stream(response.split("\n")).map(String::strip).toArray(String[]::new)); } + record ActorsFilms(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java index aaad145b5..ffa657aa9 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; +import java.io.IOException; +import java.net.URL; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.ai.openai.OpenAIServiceVersion; import com.azure.core.credential.AzureKeyCredential; @@ -23,6 +32,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -44,14 +54,6 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; -import java.io.IOException; -import java.net.URL; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; - import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; import static org.assertj.core.api.Assertions.assertThat; @@ -77,7 +79,7 @@ class AzureOpenAiChatModelIT { UserMessage userMessage = new UserMessage("Generate the names of 5 famous pirates."); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -96,12 +98,12 @@ class AzureOpenAiChatModelIT { Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard"); var promptWithMessageHistory = new Prompt(List.of(new UserMessage("Dummy"), response.getResult().getOutput(), new UserMessage("Repeat the last assistant message."))); - response = chatModel.call(promptWithMessageHistory); + response = this.chatModel.call(promptWithMessageHistory); System.out.println(response.getResult().getOutput().getContent()); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard"); @@ -120,7 +122,7 @@ class AzureOpenAiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); List list = outputConverter.convert(generation.getOutput().getContent()); assertThat(list).hasSize(5); @@ -139,7 +141,7 @@ class AzureOpenAiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -158,7 +160,7 @@ class AzureOpenAiChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isNotNull(); @@ -176,7 +178,7 @@ class AzureOpenAiChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -197,7 +199,7 @@ class AzureOpenAiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -221,7 +223,7 @@ class AzureOpenAiChatModelIT { URL url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() .options(AzureOpenAiChatOptions.builder().withDeploymentName("gpt-4o").build()) .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url)) .call() @@ -239,7 +241,7 @@ class AzureOpenAiChatModelIT { Resource resource = new ClassPathResource("multimodality/multimodal.test.png"); // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() .options(AzureOpenAiChatOptions.builder().withDeploymentName("gpt-4o").build()) .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, resource)) .call() @@ -252,9 +254,11 @@ class AzureOpenAiChatModelIT { } record ActorsFilms(String actor, List movies) { + } record ActorsFilmsRecord(String actor, List movies) { + } @SpringBootConfiguration diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java index 52a184eb6..2e194ea2d 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.azure.openai; -import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.azure.openai; import java.util.List; import java.util.stream.Collectors; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.OpenAIServiceVersion; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.core.http.policy.HttpLogOptions; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; @@ -37,13 +42,8 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import com.azure.ai.openai.OpenAIClientBuilder; -import com.azure.ai.openai.OpenAIServiceVersion; -import com.azure.core.credential.AzureKeyCredential; -import com.azure.core.http.policy.HttpLogOptions; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import reactor.core.publisher.Flux; +import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Soby Chacko @@ -54,14 +54,14 @@ import reactor.core.publisher.Flux; class AzureOpenAiChatModelObservationIT { @Autowired - private AzureOpenAiChatModel chatModel; + TestObservationRegistry observationRegistry; @Autowired - TestObservationRegistry observationRegistry; + private AzureOpenAiChatModel chatModel; @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -78,7 +78,7 @@ class AzureOpenAiChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -102,7 +102,7 @@ class AzureOpenAiChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); assertThat(responses).hasSizeGreaterThan(10); @@ -123,7 +123,7 @@ class AzureOpenAiChatModelObservationIT { private void validate(ChatResponseMetadata responseMetadata, boolean checkModel) { - TestObservationRegistryAssert.That that = TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.That that = TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME); diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java index 0ee62a147..c18114c67 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import java.util.List; @@ -22,6 +23,7 @@ import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.beans.factory.annotation.Autowired; @@ -41,18 +43,18 @@ class AzureOpenAiEmbeddingModelIT { @Test void singleEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); - System.out.println(embeddingModel.dimensions()); - assertThat(embeddingModel.dimensions()).isEqualTo(1536); + System.out.println(this.embeddingModel.dimensions()); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1536); } @Test void batchEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); assertThat(embeddingResponse.getResults()).hasSize(2); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); @@ -60,7 +62,7 @@ class AzureOpenAiEmbeddingModelIT { assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); - assertThat(embeddingModel.dimensions()).isEqualTo(1536); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1536); } @SpringBootConfiguration diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelObservationIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelObservationIT.java index db7b05dfc..dc94e4a94 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelObservationIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.azure.openai; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.azure.openai; import java.util.List; +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.core.credential.AzureKeyCredential; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -35,12 +40,7 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import com.azure.ai.openai.OpenAIClient; -import com.azure.ai.openai.OpenAIClientBuilder; -import com.azure.core.credential.AzureKeyCredential; - -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * Integration tests for observation instrumentation in {@link AzureOpenAiEmbeddingModel}. @@ -69,13 +69,13 @@ public class AzureOpenAiEmbeddingModelObservationIT { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAiTestConfiguration.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAiTestConfiguration.java index 124cd4485..48df6e123 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAiTestConfiguration.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAiTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.azure.openai; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +package org.springframework.ai.azure.openai; import java.io.IOException; import java.io.UnsupportedEncodingException; @@ -26,8 +25,14 @@ import java.util.Optional; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedDeque; +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import okio.Buffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.InitializingBean; @@ -43,11 +48,7 @@ import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import okhttp3.mockwebserver.Dispatcher; -import okhttp3.mockwebserver.MockResponse; -import okhttp3.mockwebserver.MockWebServer; -import okhttp3.mockwebserver.RecordedRequest; -import okio.Buffer; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** * Spring {@link Configuration} for AI integration testing using mock objects. @@ -205,22 +206,22 @@ public class MockAiTestConfiguration { */ static class MockWebServerFactoryBean implements FactoryBean, InitializingBean, DisposableBean { - private Dispatcher dispatcher; - private final Logger logger = LoggerFactory.getLogger(getClass().getName()); - private MockWebServer mockWebServer; - private final Queue queuedResponses = new ConcurrentLinkedDeque<>(); - public void setDispatcher(@Nullable Dispatcher dispatcher) { - this.dispatcher = dispatcher; - } + private Dispatcher dispatcher; + + private MockWebServer mockWebServer; protected Optional getDispatcher() { return Optional.ofNullable(this.dispatcher); } + public void setDispatcher(@Nullable Dispatcher dispatcher) { + this.dispatcher = dispatcher; + } + protected Logger getLogger() { return this.logger; } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java index e4a12a846..1c0a84cad 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; -import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; +import okhttp3.HttpUrl; +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockWebServer; import org.springframework.boot.SpringBootConfiguration; import org.springframework.context.annotation.Bean; @@ -24,10 +27,6 @@ import org.springframework.context.annotation.Import; import org.springframework.context.annotation.Profile; import org.springframework.test.web.servlet.MockMvc; -import okhttp3.HttpUrl; -import okhttp3.mockwebserver.Dispatcher; -import okhttp3.mockwebserver.MockWebServer; - /** * {@link SpringBootConfiguration} for testing {@literal Azure OpenAI's} API using mock * objects. diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHintsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHintsTests.java index eaec2cbdd..8984fe5a3 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHintsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai.aot; import java.util.Set; diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index 635407cd7..19bc2c730 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai.function; import java.util.ArrayList; @@ -22,21 +23,21 @@ import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; -import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.beans.factory.annotation.Autowired; @@ -44,7 +45,6 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; import static org.assertj.core.api.Assertions.assertThat; @@ -69,7 +69,7 @@ class AzureOpenAiChatModelFunctionCallIT { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) + .withDeploymentName(this.selectedModel) .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the current weather in a given location") @@ -77,7 +77,7 @@ class AzureOpenAiChatModelFunctionCallIT { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -93,7 +93,7 @@ class AzureOpenAiChatModelFunctionCallIT { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) + .withDeploymentName(this.selectedModel) .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the current weather in a given location") @@ -101,7 +101,7 @@ class AzureOpenAiChatModelFunctionCallIT { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -115,7 +115,7 @@ class AzureOpenAiChatModelFunctionCallIT { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) + .withDeploymentName(this.selectedModel) .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the current weather in a given location") @@ -123,7 +123,7 @@ class AzureOpenAiChatModelFunctionCallIT { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); final var counter = new AtomicInteger(); String content = response.doOnEach(listSignal -> counter.getAndIncrement()) @@ -152,7 +152,7 @@ class AzureOpenAiChatModelFunctionCallIT { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AzureOpenAiChatOptions.builder() - .withDeploymentName(selectedModel) + .withDeploymentName(this.selectedModel) .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the current weather in a given location") @@ -160,7 +160,7 @@ class AzureOpenAiChatModelFunctionCallIT { .build())) .build(); - var response = chatModel.stream(new Prompt(messages, promptOptions)); + var response = this.chatModel.stream(new Prompt(messages, promptOptions)); final var counter = new AtomicInteger(); String content = response.doOnEach(listSignal -> counter.getAndIncrement()) @@ -182,6 +182,16 @@ class AzureOpenAiChatModelFunctionCallIT { @SpringBootConfiguration public static class TestConfiguration { + public static String getDeploymentName() { + String deploymentName = System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"); + if (StringUtils.hasText(deploymentName)) { + return deploymentName; + } + else { + return "gpt-4o"; + } + } + @Bean public OpenAIClientBuilder openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) @@ -199,16 +209,6 @@ class AzureOpenAiChatModelFunctionCallIT { return Optional.ofNullable(System.getenv("AZURE_OPENAI_MODEL")).orElse(getDeploymentName()); } - public static String getDeploymentName() { - String deploymentName = System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"); - if (StringUtils.hasText(deploymentName)) { - return deploymentName; - } - else { - return "gpt-4o"; - } - } - } } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/MockWeatherService.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/MockWeatherService.java index 92747ed30..e122e5f69 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/MockWeatherService.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,29 +13,37 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai.function; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * @author Christian Tzolov */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -63,28 +71,23 @@ public class MockWeatherService implements Function + + 4.0.0 diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java index 61389fffa..639409044 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock; import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetrics; @@ -27,10 +28,6 @@ import org.springframework.util.Assert; */ public class BedrockUsage implements Usage { - public static BedrockUsage from(AmazonBedrockInvocationMetrics usage) { - return new BedrockUsage(usage); - } - private final AmazonBedrockInvocationMetrics usage; protected BedrockUsage(AmazonBedrockInvocationMetrics usage) { @@ -38,6 +35,10 @@ public class BedrockUsage implements Usage { this.usage = usage; } + public static BedrockUsage from(AmazonBedrockInvocationMetrics usage) { + return new BedrockUsage(usage); + } + protected AmazonBedrockInvocationMetrics getUsage() { return this.usage; } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java index 95abde872..001b2fd98 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock; import java.util.List; @@ -33,12 +34,12 @@ public class MessageToPromptConverter { private static final String ASSISTANT_PROMPT = "Assistant:"; + private final String lineSeparator; + private String humanPrompt = HUMAN_PROMPT; private String assistantPrompt = ASSISTANT_PROMPT; - private final String lineSeparator; - private MessageToPromptConverter(String lineSeparator) { this.lineSeparator = lineSeparator; } @@ -84,9 +85,9 @@ public class MessageToPromptConverter { case SYSTEM: return message.getContent(); case USER: - return humanPrompt + " " + message.getContent(); + return this.humanPrompt + " " + message.getContent(); case ASSISTANT: - return assistantPrompt + " " + message.getContent(); + return this.assistantPrompt + " " + message.getContent(); case TOOL: throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models"); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java index 7bdc15e2d..5625a55a1 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic; import java.util.List; @@ -20,11 +21,10 @@ import java.util.List; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.chat.prompt.ChatOptions; -import com.fasterxml.jackson.annotation.JsonProperty; - /** * @author Christian Tzolov * @author Thomas Vitale @@ -75,44 +75,14 @@ public class AnthropicChatOptions implements ChatOptions { return new Builder(); } - public static class Builder { - - private final AnthropicChatOptions options = new AnthropicChatOptions(); - - public Builder withTemperature(Double temperature) { - this.options.setTemperature(temperature); - return this; - } - - public Builder withMaxTokensToSample(Integer maxTokensToSample) { - this.options.setMaxTokensToSample(maxTokensToSample); - return this; - } - - public Builder withTopK(Integer topK) { - this.options.setTopK(topK); - return this; - } - - public Builder withTopP(Double topP) { - this.options.setTopP(topP); - return this; - } - - public Builder withStopSequences(List stopSequences) { - this.options.setStopSequences(stopSequences); - return this; - } - - public Builder withAnthropicVersion(String anthropicVersion) { - this.options.setAnthropicVersion(anthropicVersion); - return this; - } - - public AnthropicChatOptions build() { - return this.options; - } - + public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) { + return builder().withTemperature(fromOptions.getTemperature()) + .withMaxTokensToSample(fromOptions.getMaxTokensToSample()) + .withTopK(fromOptions.getTopK()) + .withTopP(fromOptions.getTopP()) + .withStopSequences(fromOptions.getStopSequences()) + .withAnthropicVersion(fromOptions.getAnthropicVersion()) + .build(); } @Override @@ -201,14 +171,44 @@ public class AnthropicChatOptions implements ChatOptions { return fromOptions(this); } - public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) { - return builder().withTemperature(fromOptions.getTemperature()) - .withMaxTokensToSample(fromOptions.getMaxTokensToSample()) - .withTopK(fromOptions.getTopK()) - .withTopP(fromOptions.getTopP()) - .withStopSequences(fromOptions.getStopSequences()) - .withAnthropicVersion(fromOptions.getAnthropicVersion()) - .build(); + public static class Builder { + + private final AnthropicChatOptions options = new AnthropicChatOptions(); + + public Builder withTemperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public Builder withMaxTokensToSample(Integer maxTokensToSample) { + this.options.setMaxTokensToSample(maxTokensToSample); + return this; + } + + public Builder withTopK(Integer topK) { + this.options.setTopK(topK); + return this; + } + + public Builder withTopP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public Builder withStopSequences(List stopSequences) { + this.options.setStopSequences(stopSequences); + return this; + } + + public Builder withAnthropicVersion(String anthropicVersion) { + this.options.setAnthropicVersion(anthropicVersion); + return this; + } + + public AnthropicChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java index c4321c374..f5f1f91be 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,22 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic; import java.util.List; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import reactor.core.publisher.Flux; import org.springframework.ai.bedrock.MessageToPromptConverter; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse; -import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java index c1235456b..074c03635 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic.api; import java.time.Duration; @@ -118,6 +119,54 @@ public class AnthropicChatBedrockApi extends // Anthropic Claude models: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html + @Override + public AnthropicChatResponse chatCompletion(AnthropicChatRequest anthropicRequest) { + Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); + return this.internalInvocation(anthropicRequest, AnthropicChatResponse.class); + } + + @Override + public Flux chatCompletionStream(AnthropicChatRequest anthropicRequest) { + Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); + return this.internalInvocationStream(anthropicRequest, AnthropicChatResponse.class); + } + + /** + * Anthropic models version. + */ + public enum AnthropicChatModel implements ChatModelDescription { + /** + * anthropic.claude-instant-v1 + */ + CLAUDE_INSTANT_V1("anthropic.claude-instant-v1"), + /** + * anthropic.claude-v2 + */ + CLAUDE_V2("anthropic.claude-v2"), + /** + * anthropic.claude-v2:1 + */ + CLAUDE_V21("anthropic.claude-v2:1"); + + private final String id; + + AnthropicChatModel(String value) { + this.id = value; + } + + /** + * @return The model id. + */ + public String id() { + return this.id; + } + + @Override + public String getName() { + return this.id; + } + } + /** * AnthropicChatRequest encapsulates the request parameters for the Anthropic chat model. * https://docs.anthropic.com/claude/reference/complete_post @@ -196,13 +245,13 @@ public class AnthropicChatBedrockApi extends public AnthropicChatRequest build() { return new AnthropicChatRequest( - prompt, - temperature, - maxTokensToSample, - topK, - topP, - stopSequences, - anthropicVersion + this.prompt, + this.temperature, + this.maxTokensToSample, + this.topK, + this.topP, + this.stopSequences, + this.anthropicVersion ); } } @@ -225,53 +274,5 @@ public class AnthropicChatBedrockApi extends @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { } - /** - * Anthropic models version. - */ - public enum AnthropicChatModel implements ChatModelDescription { - /** - * anthropic.claude-instant-v1 - */ - CLAUDE_INSTANT_V1("anthropic.claude-instant-v1"), - /** - * anthropic.claude-v2 - */ - CLAUDE_V2("anthropic.claude-v2"), - /** - * anthropic.claude-v2:1 - */ - CLAUDE_V21("anthropic.claude-v2:1"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; - } - - AnthropicChatModel(String value) { - this.id = value; - } - - @Override - public String getName() { - return this.id; - } - } - - @Override - public AnthropicChatResponse chatCompletion(AnthropicChatRequest anthropicRequest) { - Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); - return this.internalInvocation(anthropicRequest, AnthropicChatResponse.class); - } - - @Override - public Flux chatCompletionStream(AnthropicChatRequest anthropicRequest) { - Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); - return this.internalInvocationStream(anthropicRequest, AnthropicChatResponse.class); - } - } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java index 45927c911..86d8ad67b 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic3; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.chat.prompt.ChatOptions; -import java.util.List; +import org.springframework.ai.chat.prompt.ChatOptions; /** * @author Ben Middleton @@ -74,44 +76,14 @@ public class Anthropic3ChatOptions implements ChatOptions { return new Builder(); } - public static class Builder { - - private final Anthropic3ChatOptions options = new Anthropic3ChatOptions(); - - public Builder withTemperature(Double temperature) { - this.options.setTemperature(temperature); - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.setMaxTokens(maxTokens); - return this; - } - - public Builder withTopK(Integer topK) { - this.options.setTopK(topK); - return this; - } - - public Builder withTopP(Double topP) { - this.options.setTopP(topP); - return this; - } - - public Builder withStopSequences(List stopSequences) { - this.options.setStopSequences(stopSequences); - return this; - } - - public Builder withAnthropicVersion(String anthropicVersion) { - this.options.setAnthropicVersion(anthropicVersion); - return this; - } - - public Anthropic3ChatOptions build() { - return this.options; - } - + public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) { + return builder().withTemperature(fromOptions.getTemperature()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withTopK(fromOptions.getTopK()) + .withTopP(fromOptions.getTopP()) + .withStopSequences(fromOptions.getStopSequences()) + .withAnthropicVersion(fromOptions.getAnthropicVersion()) + .build(); } @Override @@ -190,14 +162,44 @@ public class Anthropic3ChatOptions implements ChatOptions { return fromOptions(this); } - public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) { - return builder().withTemperature(fromOptions.getTemperature()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withTopK(fromOptions.getTopK()) - .withTopP(fromOptions.getTopP()) - .withStopSequences(fromOptions.getStopSequences()) - .withAnthropicVersion(fromOptions.getAnthropicVersion()) - .build(); + public static class Builder { + + private final Anthropic3ChatOptions options = new Anthropic3ChatOptions(); + + public Builder withTemperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.setMaxTokens(maxTokens); + return this; + } + + public Builder withTopK(Integer topK) { + this.options.setTopK(topK); + return this; + } + + public Builder withTopP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public Builder withStopSequences(List stopSequences) { + this.options.setStopSequences(stopSequences); + return this; + } + + public Builder withAnthropicVersion(String anthropicVersion) { + this.options.setAnthropicVersion(anthropicVersion); + return this; + } + + public Anthropic3ChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java index 4bef201bb..deaa01f13 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic3; import java.util.ArrayList; @@ -21,11 +22,6 @@ import java.util.List; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.metadata.DefaultUsage; -import org.springframework.ai.chat.metadata.Usage; import reactor.core.publisher.Flux; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; @@ -35,13 +31,18 @@ import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.An import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.MessageType; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java index d9f9cf672..e6e8b9611 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic3.api; +import java.time.Duration; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.ObjectMapper; +import reactor.core.publisher.Flux; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatRequest; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse; import org.springframework.ai.bedrock.api.AbstractBedrockApi; import org.springframework.ai.model.ChatModelDescription; import org.springframework.util.Assert; -import reactor.core.publisher.Flux; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.Region; - -import java.time.Duration; -import java.util.List; /** * Based on Bedrock's chatCompletionStream(AnthropicChatRequest anthropicRequest) { + Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); + return this.internalInvocationStream(anthropicRequest, AnthropicChatStreamingResponse.class); + } + + /** + * Anthropic models version. + */ + public enum AnthropicChatModel implements ChatModelDescription { + + /** + * anthropic.claude-instant-v1 + */ + CLAUDE_INSTANT_V1("anthropic.claude-instant-v1"), + /** + * anthropic.claude-v2 + */ + CLAUDE_V2("anthropic.claude-v2"), + /** + * anthropic.claude-v2:1 + */ + CLAUDE_V21("anthropic.claude-v2:1"), + /** + * anthropic.claude-3-sonnet-20240229-v1:0 + */ + CLAUDE_V3_SONNET("anthropic.claude-3-sonnet-20240229-v1:0"), + /** + * anthropic.claude-3-haiku-20240307-v1:0 + */ + CLAUDE_V3_HAIKU("anthropic.claude-3-haiku-20240307-v1:0"), + /** + * anthropic.claude-3-opus-20240229-v1:0 + */ + CLAUDE_V3_OPUS("anthropic.claude-3-opus-20240229-v1:0"), + /** + * anthropic.claude-3-5-sonnet-20240620-v1:0 + */ + CLAUDE_V3_5_SONNET("anthropic.claude-3-5-sonnet-20240620-v1:0"), + /** + * anthropic.claude-3-5-sonnet-20241022-v2:0 + */ + CLAUDE_V3_5_SONNET_V2("anthropic.claude-3-5-sonnet-20241022-v2:0"); + + private final String id; + + AnthropicChatModel(String value) { + this.id = value; + } + + /** + * @return The model id. + */ + public String id() { + return this.id; + } + + @Override + public String getName() { + return this.id; + } + + } + /** * AnthropicChatRequest encapsulates the request parameters for the Anthropic messages model. * https://docs.anthropic.com/claude/reference/messages_post @@ -208,14 +280,14 @@ public class Anthropic3ChatBedrockApi extends public AnthropicChatRequest build() { return new AnthropicChatRequest( - messages, - system, - temperature, - maxTokens, - topK, - topP, - stopSequences, - anthropicVersion + this.messages, + this.system, + this.temperature, + this.maxTokens, + this.topK, + this.topP, + this.stopSequences, + this.anthropicVersion ); } } @@ -286,7 +358,9 @@ public class Anthropic3ChatBedrockApi extends public Source(String mediaType, String data) { this("base64", mediaType, data); } + } + } /** @@ -317,6 +391,7 @@ public class Anthropic3ChatBedrockApi extends ASSISTANT } + } /** @@ -329,6 +404,7 @@ public class Anthropic3ChatBedrockApi extends @JsonInclude(Include.NON_NULL) public record AnthropicUsage(@JsonProperty("input_tokens") Integer inputTokens, @JsonProperty("output_tokens") Integer outputTokens) { + } /** @@ -356,6 +432,7 @@ public class Anthropic3ChatBedrockApi extends @JsonProperty("stop_reason") String stopReason, @JsonProperty("stop_sequence") String stopSequence, @JsonProperty("usage") AnthropicUsage usage, @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { // formatter:on + } /** @@ -432,77 +509,9 @@ public class Anthropic3ChatBedrockApi extends @JsonInclude(Include.NON_NULL) public record Delta(@JsonProperty("type") String type, @JsonProperty("text") String text, @JsonProperty("stop_reason") String stopReason, @JsonProperty("stop_sequence") String stopSequence) { - } - } - /** - * Anthropic models version. - */ - public enum AnthropicChatModel implements ChatModelDescription { - - /** - * anthropic.claude-instant-v1 - */ - CLAUDE_INSTANT_V1("anthropic.claude-instant-v1"), - /** - * anthropic.claude-v2 - */ - CLAUDE_V2("anthropic.claude-v2"), - /** - * anthropic.claude-v2:1 - */ - CLAUDE_V21("anthropic.claude-v2:1"), - /** - * anthropic.claude-3-sonnet-20240229-v1:0 - */ - CLAUDE_V3_SONNET("anthropic.claude-3-sonnet-20240229-v1:0"), - /** - * anthropic.claude-3-haiku-20240307-v1:0 - */ - CLAUDE_V3_HAIKU("anthropic.claude-3-haiku-20240307-v1:0"), - /** - * anthropic.claude-3-opus-20240229-v1:0 - */ - CLAUDE_V3_OPUS("anthropic.claude-3-opus-20240229-v1:0"), - /** - * anthropic.claude-3-5-sonnet-20240620-v1:0 - */ - CLAUDE_V3_5_SONNET("anthropic.claude-3-5-sonnet-20240620-v1:0"), - /** - * anthropic.claude-3-5-sonnet-20241022-v2:0 - */ - CLAUDE_V3_5_SONNET_V2("anthropic.claude-3-5-sonnet-20241022-v2:0"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; } - AnthropicChatModel(String value) { - this.id = value; - } - - @Override - public String getName() { - return this.id; - } - - } - - @Override - public AnthropicChatResponse chatCompletion(AnthropicChatRequest anthropicRequest) { - Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); - return this.internalInvocation(anthropicRequest, AnthropicChatResponse.class); - } - - @Override - public Flux chatCompletionStream(AnthropicChatRequest anthropicRequest) { - Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); - return this.internalInvocationStream(anthropicRequest, AnthropicChatStreamingResponse.class); } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java index 7db24b3b8..b6f93d3b8 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.aot; import org.springframework.ai.bedrock.anthropic.AnthropicChatOptions; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java index 24a383ada..11897200c 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -175,24 +175,6 @@ public abstract class AbstractBedrockApi { return this.region; } - /** - * Encapsulates the metrics about the model invocation. - * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html - * - * @param inputTokenCount The number of tokens in the input prompt. - * @param firstByteLatency The time in milliseconds between the request being sent and the first byte of the - * response being received. - * @param outputTokenCount The number of tokens in the generated text. - * @param invocationLatency The time in milliseconds between the request being sent and the response being received. - */ - @JsonInclude(Include.NON_NULL) - public record AmazonBedrockInvocationMetrics( - @JsonProperty("inputTokenCount") Long inputTokenCount, - @JsonProperty("firstByteLatency") Long firstByteLatency, - @JsonProperty("outputTokenCount") Long outputTokenCount, - @JsonProperty("invocationLatency") Long invocationLatency) { - } - /** * Compute the embedding for the given text. * @@ -337,5 +319,23 @@ public abstract class AbstractBedrockApi { return eventSink.asFlux(); } + + /** + * Encapsulates the metrics about the model invocation. + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html + * + * @param inputTokenCount The number of tokens in the input prompt. + * @param firstByteLatency The time in milliseconds between the request being sent and the first byte of the + * response being received. + * @param outputTokenCount The number of tokens in the generated text. + * @param invocationLatency The time in milliseconds between the request being sent and the response being received. + */ + @JsonInclude(Include.NON_NULL) + public record AmazonBedrockInvocationMetrics( + @JsonProperty("inputTokenCount") Long inputTokenCount, + @JsonProperty("firstByteLatency") Long firstByteLatency, + @JsonProperty("outputTokenCount") Long outputTokenCount, + @JsonProperty("invocationLatency") Long invocationLatency) { + } } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java index e9895fc1d..4d235a2b2 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere; import java.util.List; @@ -24,13 +25,13 @@ import org.springframework.ai.bedrock.MessageToPromptConverter; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; -import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.util.Assert; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java index 04f67f282..8e5e0a689 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere; import java.util.List; @@ -85,59 +86,17 @@ public class BedrockCohereChatOptions implements ChatOptions { return new Builder(); } - public static class Builder { - - private final BedrockCohereChatOptions options = new BedrockCohereChatOptions(); - - public Builder withTemperature(Double temperature) { - this.options.setTemperature(temperature); - return this; - } - - public Builder withTopP(Double topP) { - this.options.setTopP(topP); - return this; - } - - public Builder withTopK(Integer topK) { - this.options.setTopK(topK); - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.setMaxTokens(maxTokens); - return this; - } - - public Builder withStopSequences(List stopSequences) { - this.options.setStopSequences(stopSequences); - return this; - } - - public Builder withReturnLikelihoods(ReturnLikelihoods returnLikelihoods) { - this.options.setReturnLikelihoods(returnLikelihoods); - return this; - } - - public Builder withNumGenerations(Integer numGenerations) { - this.options.setNumGenerations(numGenerations); - return this; - } - - public Builder withLogitBias(LogitBias logitBias) { - this.options.setLogitBias(logitBias); - return this; - } - - public Builder withTruncate(Truncate truncate) { - this.options.setTruncate(truncate); - return this; - } - - public BedrockCohereChatOptions build() { - return this.options; - } - + public static BedrockCohereChatOptions fromOptions(BedrockCohereChatOptions fromOptions) { + return builder().withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTopK(fromOptions.getTopK()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withStopSequences(fromOptions.getStopSequences()) + .withReturnLikelihoods(fromOptions.getReturnLikelihoods()) + .withNumGenerations(fromOptions.getNumGenerations()) + .withLogitBias(fromOptions.getLogitBias()) + .withTruncate(fromOptions.getTruncate()) + .build(); } @Override @@ -240,17 +199,59 @@ public class BedrockCohereChatOptions implements ChatOptions { return fromOptions(this); } - public static BedrockCohereChatOptions fromOptions(BedrockCohereChatOptions fromOptions) { - return builder().withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withTopK(fromOptions.getTopK()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withStopSequences(fromOptions.getStopSequences()) - .withReturnLikelihoods(fromOptions.getReturnLikelihoods()) - .withNumGenerations(fromOptions.getNumGenerations()) - .withLogitBias(fromOptions.getLogitBias()) - .withTruncate(fromOptions.getTruncate()) - .build(); + public static class Builder { + + private final BedrockCohereChatOptions options = new BedrockCohereChatOptions(); + + public Builder withTemperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public Builder withTopP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public Builder withTopK(Integer topK) { + this.options.setTopK(topK); + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.setMaxTokens(maxTokens); + return this; + } + + public Builder withStopSequences(List stopSequences) { + this.options.setStopSequences(stopSequences); + return this; + } + + public Builder withReturnLikelihoods(ReturnLikelihoods returnLikelihoods) { + this.options.setReturnLikelihoods(returnLikelihoods); + return this; + } + + public Builder withNumGenerations(Integer numGenerations) { + this.options.setNumGenerations(numGenerations); + return this; + } + + public Builder withLogitBias(LogitBias logitBias) { + this.options.setLogitBias(logitBias); + return this; + } + + public Builder withTruncate(Truncate truncate) { + this.options.setTruncate(truncate); + return this; + } + + public BedrockCohereChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java index c070f2506..c34335d8b 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere; import java.util.List; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingOptions.java index 068d70454..57e0de302 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -52,26 +53,6 @@ public class BedrockCohereEmbeddingOptions implements EmbeddingOptions { return new Builder(); } - public static class Builder { - - private BedrockCohereEmbeddingOptions options = new BedrockCohereEmbeddingOptions(); - - public Builder withInputType(InputType inputType) { - this.options.setInputType(inputType); - return this; - } - - public Builder withTruncate(Truncate truncate) { - this.options.setTruncate(truncate); - return this; - } - - public BedrockCohereEmbeddingOptions build() { - return this.options; - } - - } - public InputType getInputType() { return this.inputType; } @@ -100,4 +81,24 @@ public class BedrockCohereEmbeddingOptions implements EmbeddingOptions { return null; } + public static class Builder { + + private BedrockCohereEmbeddingOptions options = new BedrockCohereEmbeddingOptions(); + + public Builder withInputType(InputType inputType) { + this.options.setInputType(inputType); + return this; + } + + public Builder withTruncate(Truncate truncate) { + this.options.setTruncate(truncate); + return this; + } + + public BedrockCohereEmbeddingOptions build() { + return this.options; + } + + } + } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java index 9feb62eeb..454bd8aed 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -109,6 +109,52 @@ public class CohereChatBedrockApi extends super(modelId, credentialsProvider, region, objectMapper, timeout); } + @Override + public CohereChatResponse chatCompletion(CohereChatRequest request) { + Assert.isTrue(!request.stream(), "The request must be configured to return the complete response!"); + return this.internalInvocation(request, CohereChatResponse.class); + } + + @Override + public Flux chatCompletionStream(CohereChatRequest request) { + Assert.isTrue(request.stream(), "The request must be configured to stream the response!"); + return this.internalInvocationStream(request, CohereChatResponse.Generation.class); + } + + /** + * Cohere models version. + */ + public enum CohereChatModel implements ChatModelDescription { + + /** + * cohere.command-light-text-v14 + */ + COHERE_COMMAND_LIGHT_V14("cohere.command-light-text-v14"), + + /** + * cohere.command-text-v14 + */ + COHERE_COMMAND_V14("cohere.command-text-v14"); + + private final String id; + + CohereChatModel(String value) { + this.id = value; + } + + /** + * @return The model id. + */ + public String id() { + return this.id; + } + + @Override + public String getName() { + return this.id; + } + } + /** * CohereChatRequest encapsulates the request parameters for the Cohere command model. * @@ -143,15 +189,12 @@ public class CohereChatBedrockApi extends @JsonProperty("truncate") Truncate truncate) { /** - * Prevents the model from generating unwanted tokens or incentivize the model to include desired tokens. - * - * @param token The token likelihoods. - * @param bias A float between -10 and 10. + * Get CohereChatRequest builder. + * @param prompt compulsory request prompt parameter. + * @return CohereChatRequest builder. */ - @JsonInclude(Include.NON_NULL) - public record LogitBias( - @JsonProperty("token") String token, - @JsonProperty("bias") Float bias) { + public static Builder builder(String prompt) { + return new Builder(prompt); } /** @@ -192,12 +235,15 @@ public class CohereChatBedrockApi extends } /** - * Get CohereChatRequest builder. - * @param prompt compulsory request prompt parameter. - * @return CohereChatRequest builder. + * Prevents the model from generating unwanted tokens or incentivize the model to include desired tokens. + * + * @param token The token likelihoods. + * @param bias A float between -10 and 10. */ - public static Builder builder(String prompt) { - return new Builder(prompt); + @JsonInclude(Include.NON_NULL) + public record LogitBias( + @JsonProperty("token") String token, + @JsonProperty("bias") Float bias) { } /** @@ -272,17 +318,17 @@ public class CohereChatBedrockApi extends public CohereChatRequest build() { return new CohereChatRequest( - prompt, - temperature, - topP, - topK, - maxTokens, - stopSequences, - returnLikelihoods, - stream, - numGenerations, - logitBias, - truncate + this.prompt, + this.temperature, + this.topP, + this.topK, + this.maxTokens, + this.stopSequences, + this.returnLikelihoods, + this.stream, + this.numGenerations, + this.logitBias, + this.truncate ); } } @@ -331,16 +377,6 @@ public class CohereChatBedrockApi extends @JsonProperty("index") Integer index, @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { - /** - * @param token The token. - * @param likelihood The likelihood of the token. - */ - @JsonInclude(Include.NON_NULL) - public record TokenLikelihood( - @JsonProperty("token") String token, - @JsonProperty("likelihood") Float likelihood) { - } - /** * The reason the response finished being generated. */ @@ -363,53 +399,17 @@ public class CohereChatBedrockApi extends */ ERROR_TOXIC } + + /** + * @param token The token. + * @param likelihood The likelihood of the token. + */ + @JsonInclude(Include.NON_NULL) + public record TokenLikelihood( + @JsonProperty("token") String token, + @JsonProperty("likelihood") Float likelihood) { + } } } - - /** - * Cohere models version. - */ - public enum CohereChatModel implements ChatModelDescription { - - /** - * cohere.command-light-text-v14 - */ - COHERE_COMMAND_LIGHT_V14("cohere.command-light-text-v14"), - - /** - * cohere.command-text-v14 - */ - COHERE_COMMAND_V14("cohere.command-text-v14"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; - } - - CohereChatModel(String value) { - this.id = value; - } - - @Override - public String getName() { - return this.id; - } - } - - @Override - public CohereChatResponse chatCompletion(CohereChatRequest request) { - Assert.isTrue(!request.stream(), "The request must be configured to return the complete response!"); - return this.internalInvocation(request, CohereChatResponse.class); - } - - @Override - public Flux chatCompletionStream(CohereChatRequest request) { - Assert.isTrue(request.stream(), "The request must be configured to stream the response!"); - return this.internalInvocationStream(request, CohereChatResponse.Generation.class); - } } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java index 30938fe9a..e69f229d6 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -109,6 +109,39 @@ public class CohereEmbeddingBedrockApi extends super(modelId, credentialsProvider, region, objectMapper, timeout); } + @Override + public CohereEmbeddingResponse embedding(CohereEmbeddingRequest request) { + return this.internalInvocation(request, CohereEmbeddingResponse.class); + } + + /** + * Cohere Embedding model ids. https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html + */ + public enum CohereEmbeddingModel { + /** + * cohere.embed-multilingual-v3 + */ + COHERE_EMBED_MULTILINGUAL_V1("cohere.embed-multilingual-v3"), + /** + * cohere.embed-english-v3 + */ + COHERE_EMBED_ENGLISH_V3("cohere.embed-english-v3"); + + private final String id; + + CohereEmbeddingModel(String value) { + this.id = value; + } + + /** + * @return The model id. + */ + public String id() { + return this.id; + } + + } + /** * The Cohere Embed model request. * @@ -190,38 +223,5 @@ public class CohereEmbeddingBedrockApi extends @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { } - /** - * Cohere Embedding model ids. https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html - */ - public enum CohereEmbeddingModel { - /** - * cohere.embed-multilingual-v3 - */ - COHERE_EMBED_MULTILINGUAL_V1("cohere.embed-multilingual-v3"), - /** - * cohere.embed-english-v3 - */ - COHERE_EMBED_ENGLISH_V3("cohere.embed-english-v3"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return this.id; - } - - CohereEmbeddingModel(String value) { - this.id = value; - } - - } - - @Override - public CohereEmbeddingResponse embedding(CohereEmbeddingRequest request) { - return this.internalInvocation(request, CohereEmbeddingResponse.class); - } - } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java index 7cb9bf0ac..ab463b911 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,10 +19,10 @@ package org.springframework.ai.bedrock.jurassic2; import org.springframework.ai.bedrock.MessageToPromptConverter; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatRequest; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; @@ -57,6 +57,10 @@ public class BedrockAi21Jurassic2ChatModel implements ChatModel { .build()); } + public static Builder builder(Ai21Jurassic2ChatBedrockApi chatApi) { + return new Builder(chatApi); + } + @Override public ChatResponse call(Prompt prompt) { var request = createRequest(prompt); @@ -88,8 +92,9 @@ public class BedrockAi21Jurassic2ChatModel implements ChatModel { return request; } - public static Builder builder(Ai21Jurassic2ChatBedrockApi chatApi) { - return new Builder(chatApi); + @Override + public ChatOptions getDefaultOptions() { + return BedrockAi21Jurassic2ChatOptions.fromOptions(this.defaultOptions); } public static class Builder { @@ -108,15 +113,10 @@ public class BedrockAi21Jurassic2ChatModel implements ChatModel { } public BedrockAi21Jurassic2ChatModel build() { - return new BedrockAi21Jurassic2ChatModel(chatApi, - options != null ? options : BedrockAi21Jurassic2ChatOptions.builder().build()); + return new BedrockAi21Jurassic2ChatModel(this.chatApi, + this.options != null ? this.options : BedrockAi21Jurassic2ChatOptions.builder().build()); } } - @Override - public ChatOptions getDefaultOptions() { - return BedrockAi21Jurassic2ChatOptions.fromOptions(this.defaultOptions); - } - } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java index eb8ce968a..aa292edfa 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,12 +16,13 @@ package org.springframework.ai.bedrock.jurassic2; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.chat.prompt.ChatOptions; -import java.util.List; +import org.springframework.ai.chat.prompt.ChatOptions; /** * Request body for the /complete endpoint of the Jurassic-2 API. @@ -101,12 +102,31 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { // Getters and setters + public static Builder builder() { + return new Builder(); + } + + public static BedrockAi21Jurassic2ChatOptions fromOptions(BedrockAi21Jurassic2ChatOptions fromOptions) { + return builder().withPrompt(fromOptions.getPrompt()) + .withNumResults(fromOptions.getNumResults()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withMinTokens(fromOptions.getMinTokens()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTopK(fromOptions.getTopK()) + .withStopSequences(fromOptions.getStopSequences()) + .withFrequencyPenaltyOptions(fromOptions.getFrequencyPenaltyOptions()) + .withPresencePenaltyOptions(fromOptions.getPresencePenaltyOptions()) + .withCountPenaltyOptions(fromOptions.getCountPenaltyOptions()) + .build(); + } + /** * Gets the prompt text for the model to continue. * @return The prompt text. */ public String getPrompt() { - return prompt; + return this.prompt; } /** @@ -122,7 +142,7 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { * @return The number of results. */ public Integer getNumResults() { - return numResults; + return this.numResults; } /** @@ -139,7 +159,7 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { */ @Override public Integer getMaxTokens() { - return maxTokens; + return this.maxTokens; } /** @@ -155,7 +175,7 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { * @return The minimum number of tokens. */ public Integer getMinTokens() { - return minTokens; + return this.minTokens; } /** @@ -172,7 +192,7 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { */ @Override public Double getTemperature() { - return temperature; + return this.temperature; } /** @@ -190,7 +210,7 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { */ @Override public Double getTopP() { - return topP; + return this.topP; } /** @@ -208,7 +228,7 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { */ @Override public Integer getTopK() { - return topK; + return this.topK; } /** @@ -225,7 +245,7 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { */ @Override public List getStopSequences() { - return stopSequences; + return this.stopSequences; } /** @@ -254,7 +274,7 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { * @return The frequency penalty object. */ public Penalty getFrequencyPenaltyOptions() { - return frequencyPenaltyOptions; + return this.frequencyPenaltyOptions; } /** @@ -283,7 +303,7 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { * @return The presence penalty object. */ public Penalty getPresencePenaltyOptions() { - return presencePenaltyOptions; + return this.presencePenaltyOptions; } /** @@ -299,7 +319,7 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { * @return The count penalty object. */ public Penalty getCountPenaltyOptions() { - return countPenaltyOptions; + return this.countPenaltyOptions; } /** @@ -316,8 +336,9 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { return null; } - public static Builder builder() { - return new Builder(); + @Override + public BedrockAi21Jurassic2ChatOptions copy() { + return fromOptions(this); } public static class Builder { @@ -325,62 +346,62 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { private final BedrockAi21Jurassic2ChatOptions request = new BedrockAi21Jurassic2ChatOptions(); public Builder withPrompt(String prompt) { - request.setPrompt(prompt); + this.request.setPrompt(prompt); return this; } public Builder withNumResults(Integer numResults) { - request.setNumResults(numResults); + this.request.setNumResults(numResults); return this; } public Builder withMaxTokens(Integer maxTokens) { - request.setMaxTokens(maxTokens); + this.request.setMaxTokens(maxTokens); return this; } public Builder withMinTokens(Integer minTokens) { - request.setMinTokens(minTokens); + this.request.setMinTokens(minTokens); return this; } public Builder withTemperature(Double temperature) { - request.setTemperature(temperature); + this.request.setTemperature(temperature); return this; } public Builder withTopP(Double topP) { - request.setTopP(topP); + this.request.setTopP(topP); return this; } public Builder withStopSequences(List stopSequences) { - request.setStopSequences(stopSequences); + this.request.setStopSequences(stopSequences); return this; } public Builder withTopK(Integer topKReturn) { - request.setTopK(topKReturn); + this.request.setTopK(topKReturn); return this; } public Builder withFrequencyPenaltyOptions(BedrockAi21Jurassic2ChatOptions.Penalty frequencyPenalty) { - request.setFrequencyPenaltyOptions(frequencyPenalty); + this.request.setFrequencyPenaltyOptions(frequencyPenalty); return this; } public Builder withPresencePenaltyOptions(BedrockAi21Jurassic2ChatOptions.Penalty presencePenalty) { - request.setPresencePenaltyOptions(presencePenalty); + this.request.setPresencePenaltyOptions(presencePenalty); return this; } public Builder withCountPenaltyOptions(BedrockAi21Jurassic2ChatOptions.Penalty countPenalty) { - request.setCountPenaltyOptions(countPenalty); + this.request.setCountPenaltyOptions(countPenalty); return this; } public BedrockAi21Jurassic2ChatOptions build() { - return request; + return this.request; } } @@ -446,31 +467,12 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { } public Penalty build() { - return new Penalty(scale, applyToNumbers, applyToPunctuations, applyToStopwords, applyToWhitespaces, - applyToEmojis); + return new Penalty(this.scale, this.applyToNumbers, this.applyToPunctuations, this.applyToStopwords, + this.applyToWhitespaces, this.applyToEmojis); } } - } - @Override - public BedrockAi21Jurassic2ChatOptions copy() { - return fromOptions(this); - } - - public static BedrockAi21Jurassic2ChatOptions fromOptions(BedrockAi21Jurassic2ChatOptions fromOptions) { - return builder().withPrompt(fromOptions.getPrompt()) - .withNumResults(fromOptions.getNumResults()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withMinTokens(fromOptions.getMinTokens()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withTopK(fromOptions.getTopK()) - .withStopSequences(fromOptions.getStopSequences()) - .withFrequencyPenaltyOptions(fromOptions.getFrequencyPenaltyOptions()) - .withPresencePenaltyOptions(fromOptions.getPresencePenaltyOptions()) - .withCountPenaltyOptions(fromOptions.getCountPenaltyOptions()) - .build(); } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java index c3ab019d1..9fa7104cf 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -22,16 +22,15 @@ import java.util.List; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; - import com.fasterxml.jackson.databind.ObjectMapper; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + import org.springframework.ai.bedrock.api.AbstractBedrockApi; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatRequest; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatResponse; import org.springframework.ai.model.ChatModelDescription; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.Region; - /** * Java client for the Bedrock Jurassic2 chat model. * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html @@ -110,6 +109,45 @@ public class Ai21Jurassic2ChatBedrockApi extends super(modelId, credentialsProvider, region, objectMapper, timeout); } + @Override + public Ai21Jurassic2ChatResponse chatCompletion(Ai21Jurassic2ChatRequest request) { + return this.internalInvocation(request, Ai21Jurassic2ChatResponse.class); + } + + /** + * Ai21 Jurassic2 models version. + */ + public enum Ai21Jurassic2ChatModel implements ChatModelDescription { + + /** + * ai21.j2-mid-v1 + */ + AI21_J2_MID_V1("ai21.j2-mid-v1"), + + /** + * ai21.j2-ultra-v1 + */ + AI21_J2_ULTRA_V1("ai21.j2-ultra-v1"); + + private final String id; + + Ai21Jurassic2ChatModel(String value) { + this.id = value; + } + + /** + * @return The model id. + */ + public String id() { + return this.id; + } + + @Override + public String getName() { + return this.id; + } + } + /** * AI21 Jurassic2 chat request parameters. * @@ -141,6 +179,10 @@ public class Ai21Jurassic2ChatBedrockApi extends @JsonProperty("presencePenalty") FloatScalePenalty presencePenalty, @JsonProperty("frequencyPenalty") IntegerScalePenalty frequencyPenalty) { + public static Builder builder(String prompt) { + return new Builder(prompt); + } + /** * Penalty with integer scale value. * @@ -192,11 +234,6 @@ public class Ai21Jurassic2ChatBedrockApi extends @JsonProperty("applyToEmojis") boolean applyToEmojis) { } - - - public static Builder builder(String prompt) { - return new Builder(prompt); - } public static class Builder { private String prompt; private Double temperature; @@ -248,14 +285,14 @@ public class Ai21Jurassic2ChatBedrockApi extends public Ai21Jurassic2ChatRequest build() { return new Ai21Jurassic2ChatRequest( - prompt, - temperature, - topP, - maxTokens, - stopSequences, - countPenalty, - presencePenalty, - frequencyPenalty + this.prompt, + this.temperature, + this.topP, + this.maxTokens, + this.stopSequences, + this.countPenalty, + this.presencePenalty, + this.frequencyPenalty ); } } @@ -370,45 +407,6 @@ public class Ai21Jurassic2ChatBedrockApi extends } } - /** - * Ai21 Jurassic2 models version. - */ - public enum Ai21Jurassic2ChatModel implements ChatModelDescription { - - /** - * ai21.j2-mid-v1 - */ - AI21_J2_MID_V1("ai21.j2-mid-v1"), - - /** - * ai21.j2-ultra-v1 - */ - AI21_J2_ULTRA_V1("ai21.j2-ultra-v1"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; - } - - Ai21Jurassic2ChatModel(String value) { - this.id = value; - } - - @Override - public String getName() { - return this.id; - } - } - - @Override - public Ai21Jurassic2ChatResponse chatCompletion(Ai21Jurassic2ChatRequest request) { - return this.internalInvocation(request, Ai21Jurassic2ChatResponse.class); - } - } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java index 51b83a7be..1944b85ab 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.llama; import java.util.List; @@ -23,13 +24,13 @@ import org.springframework.ai.bedrock.MessageToPromptConverter; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; -import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.util.Assert; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java index ed50bd3c5..bdeb7543a 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.llama; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -22,8 +25,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.chat.prompt.ChatOptions; -import java.util.List; - /** * @author Christian Tzolov * @author Thomas Vitale @@ -52,29 +53,11 @@ public class BedrockLlamaChatOptions implements ChatOptions { return new Builder(); } - public static class Builder { - - private BedrockLlamaChatOptions options = new BedrockLlamaChatOptions(); - - public Builder withTemperature(Double temperature) { - this.options.setTemperature(temperature); - return this; - } - - public Builder withTopP(Double topP) { - this.options.setTopP(topP); - return this; - } - - public Builder withMaxGenLen(Integer maxGenLen) { - this.options.setMaxGenLen(maxGenLen); - return this; - } - - public BedrockLlamaChatOptions build() { - return this.options; - } - + public static BedrockLlamaChatOptions fromOptions(BedrockLlamaChatOptions fromOptions) { + return builder().withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withMaxGenLen(fromOptions.getMaxGenLen()) + .build(); } @Override @@ -149,11 +132,29 @@ public class BedrockLlamaChatOptions implements ChatOptions { return fromOptions(this); } - public static BedrockLlamaChatOptions fromOptions(BedrockLlamaChatOptions fromOptions) { - return builder().withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withMaxGenLen(fromOptions.getMaxGenLen()) - .build(); + public static class Builder { + + private BedrockLlamaChatOptions options = new BedrockLlamaChatOptions(); + + public Builder withTemperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public Builder withTopP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public Builder withMaxGenLen(Integer maxGenLen) { + this.options.setMaxGenLen(maxGenLen); + return this; + } + + public BedrockLlamaChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java index a476f70cf..4a76ee485 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.llama.api; +import java.time.Duration; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; @@ -28,8 +31,6 @@ import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatReq import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse; import org.springframework.ai.model.ChatModelDescription; -import java.time.Duration; - // @formatter:off /** * Java client for the Bedrock Llama chat model. @@ -109,100 +110,14 @@ public class LlamaChatBedrockApi extends super(modelId, credentialsProvider, region, objectMapper, timeout); } - /** - * LlamaChatRequest encapsulates the request parameters for the Meta Llama chat model. - * - * @param prompt The prompt to use for the chat. - * @param temperature The temperature value controls the randomness of the generated text. Use a lower value to - * decrease randomness in the response. - * @param topP The topP value controls the diversity of the generated text. Use a lower value to ignore less - * probable options. Set to 0 or 1.0 to disable. - * @param maxGenLen The maximum length of the generated text. - */ - @JsonInclude(Include.NON_NULL) - public record LlamaChatRequest( - @JsonProperty("prompt") String prompt, - @JsonProperty("temperature") Double temperature, - @JsonProperty("top_p") Double topP, - @JsonProperty("max_gen_len") Integer maxGenLen) { - - /** - * Create a new LlamaChatRequest builder. - * @param prompt compulsory prompt parameter. - * @return a new LlamaChatRequest builder. - */ - public static Builder builder(String prompt) { - return new Builder(prompt); - } - - public static class Builder { - private String prompt; - private Double temperature; - private Double topP; - private Integer maxGenLen; - - public Builder(String prompt) { - this.prompt = prompt; - } - - public Builder withTemperature(Double temperature) { - this.temperature = temperature; - return this; - } - - public Builder withTopP(Double topP) { - this.topP = topP; - return this; - } - - public Builder withMaxGenLen(Integer maxGenLen) { - this.maxGenLen = maxGenLen; - return this; - } - - public LlamaChatRequest build() { - return new LlamaChatRequest( - prompt, - temperature, - topP, - maxGenLen - ); - } - } + @Override + public LlamaChatResponse chatCompletion(LlamaChatRequest request) { + return this.internalInvocation(request, LlamaChatResponse.class); } - /** - * LlamaChatResponse encapsulates the response parameters for the Meta Llama chat model. - * - * @param generation The generated text. - * @param promptTokenCount The number of tokens in the prompt. - * @param generationTokenCount The number of tokens in the response. - * @param stopReason The reason why the response stopped generating text. Possible values are: (1) stop – The model - * has finished generating text for the input prompt. (2) length – The length of the tokens for the generated text - * exceeds the value of max_gen_len in the call. The response is truncated to max_gen_len tokens. Consider - * increasing the value of max_gen_len and trying again. - */ - @JsonInclude(Include.NON_NULL) - public record LlamaChatResponse( - @JsonProperty("generation") String generation, - @JsonProperty("prompt_token_count") Integer promptTokenCount, - @JsonProperty("generation_token_count") Integer generationTokenCount, - @JsonProperty("stop_reason") StopReason stopReason, - @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { - - /** - * The reason the response finished being generated. - */ - public enum StopReason { - /** - * The model has finished generating text for the input prompt. - */ - @JsonProperty("stop") STOP, - /** - * The response was truncated because of the response length you set. - */ - @JsonProperty("length") LENGTH - } + @Override + public Flux chatCompletionStream(LlamaChatRequest request) { + return this.internalInvocationStream(request, LlamaChatResponse.class); } /** @@ -267,15 +182,15 @@ public class LlamaChatBedrockApi extends private final String id; + LlamaChatModel(String value) { + this.id = value; + } + /** * @return The model id. */ public String id() { - return id; - } - - LlamaChatModel(String value) { - this.id = value; + return this.id; } @Override @@ -284,14 +199,100 @@ public class LlamaChatBedrockApi extends } } - @Override - public LlamaChatResponse chatCompletion(LlamaChatRequest request) { - return this.internalInvocation(request, LlamaChatResponse.class); + /** + * LlamaChatRequest encapsulates the request parameters for the Meta Llama chat model. + * + * @param prompt The prompt to use for the chat. + * @param temperature The temperature value controls the randomness of the generated text. Use a lower value to + * decrease randomness in the response. + * @param topP The topP value controls the diversity of the generated text. Use a lower value to ignore less + * probable options. Set to 0 or 1.0 to disable. + * @param maxGenLen The maximum length of the generated text. + */ + @JsonInclude(Include.NON_NULL) + public record LlamaChatRequest( + @JsonProperty("prompt") String prompt, + @JsonProperty("temperature") Double temperature, + @JsonProperty("top_p") Double topP, + @JsonProperty("max_gen_len") Integer maxGenLen) { + + /** + * Create a new LlamaChatRequest builder. + * @param prompt compulsory prompt parameter. + * @return a new LlamaChatRequest builder. + */ + public static Builder builder(String prompt) { + return new Builder(prompt); + } + + public static class Builder { + private String prompt; + private Double temperature; + private Double topP; + private Integer maxGenLen; + + public Builder(String prompt) { + this.prompt = prompt; + } + + public Builder withTemperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.topP = topP; + return this; + } + + public Builder withMaxGenLen(Integer maxGenLen) { + this.maxGenLen = maxGenLen; + return this; + } + + public LlamaChatRequest build() { + return new LlamaChatRequest( + this.prompt, + this.temperature, + this.topP, + this.maxGenLen + ); + } + } } - @Override - public Flux chatCompletionStream(LlamaChatRequest request) { - return this.internalInvocationStream(request, LlamaChatResponse.class); + /** + * LlamaChatResponse encapsulates the response parameters for the Meta Llama chat model. + * + * @param generation The generated text. + * @param promptTokenCount The number of tokens in the prompt. + * @param generationTokenCount The number of tokens in the response. + * @param stopReason The reason why the response stopped generating text. Possible values are: (1) stop – The model + * has finished generating text for the input prompt. (2) length – The length of the tokens for the generated text + * exceeds the value of max_gen_len in the call. The response is truncated to max_gen_len tokens. Consider + * increasing the value of max_gen_len and trying again. + */ + @JsonInclude(Include.NON_NULL) + public record LlamaChatResponse( + @JsonProperty("generation") String generation, + @JsonProperty("prompt_token_count") Integer promptTokenCount, + @JsonProperty("generation_token_count") Integer generationTokenCount, + @JsonProperty("stop_reason") StopReason stopReason, + @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { + + /** + * The reason the response finished being generated. + */ + public enum StopReason { + /** + * The model has finished generating text for the input prompt. + */ + @JsonProperty("stop") STOP, + /** + * The response was truncated because of the response length you set. + */ + @JsonProperty("length") LENGTH + } } } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java index e6d55a03b..1003ef044 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan; import java.util.List; @@ -24,13 +25,13 @@ import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatRequest; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; -import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.util.Assert; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java index d1187f118..a5d06bdba 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan; import java.util.List; @@ -20,11 +21,10 @@ import java.util.List; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.chat.prompt.ChatOptions; -import com.fasterxml.jackson.annotation.JsonProperty; - /** * @author Christian Tzolov * @author Thomas Vitale @@ -59,39 +59,17 @@ public class BedrockTitanChatOptions implements ChatOptions { return new Builder(); } - public static class Builder { - - private BedrockTitanChatOptions options = new BedrockTitanChatOptions(); - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public Builder withMaxTokenCount(Integer maxTokenCount) { - this.options.maxTokenCount = maxTokenCount; - return this; - } - - public Builder withStopSequences(List stopSequences) { - this.options.stopSequences = stopSequences; - return this; - } - - public BedrockTitanChatOptions build() { - return this.options; - } - + public static BedrockTitanChatOptions fromOptions(BedrockTitanChatOptions fromOptions) { + return builder().withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withMaxTokenCount(fromOptions.getMaxTokenCount()) + .withStopSequences(fromOptions.getStopSequences()) + .build(); } @Override public Double getTemperature() { - return temperature; + return this.temperature; } public void setTemperature(Double temperature) { @@ -100,7 +78,7 @@ public class BedrockTitanChatOptions implements ChatOptions { @Override public Double getTopP() { - return topP; + return this.topP; } public void setTopP(Double topP) { @@ -119,7 +97,7 @@ public class BedrockTitanChatOptions implements ChatOptions { } public Integer getMaxTokenCount() { - return maxTokenCount; + return this.maxTokenCount; } public void setMaxTokenCount(Integer maxTokenCount) { @@ -128,7 +106,7 @@ public class BedrockTitanChatOptions implements ChatOptions { @Override public List getStopSequences() { - return stopSequences; + return this.stopSequences; } public void setStopSequences(List stopSequences) { @@ -164,12 +142,34 @@ public class BedrockTitanChatOptions implements ChatOptions { return fromOptions(this); } - public static BedrockTitanChatOptions fromOptions(BedrockTitanChatOptions fromOptions) { - return builder().withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withMaxTokenCount(fromOptions.getMaxTokenCount()) - .withStopSequences(fromOptions.getStopSequences()) - .build(); + public static class Builder { + + private BedrockTitanChatOptions options = new BedrockTitanChatOptions(); + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder withMaxTokenCount(Integer maxTokenCount) { + this.options.maxTokenCount = maxTokenCount; + return this; + } + + public Builder withStopSequences(List stopSequences) { + this.options.stopSequences = stopSequences; + return this; + } + + public BedrockTitanChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java index 84a646b84..c07527b0e 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan; import java.util.ArrayList; @@ -50,12 +51,6 @@ public class BedrockTitanEmbeddingModel extends AbstractEmbeddingModel { private final TitanEmbeddingBedrockApi embeddingApi; - public enum InputType { - - TEXT, IMAGE - - } - /** * Titan Embedding API input types. Could be either text or image (encoded in base64). */ @@ -83,7 +78,7 @@ public class BedrockTitanEmbeddingModel extends AbstractEmbeddingModel { public EmbeddingResponse call(EmbeddingRequest request) { Assert.notEmpty(request.getInstructions(), "At least one text is required!"); if (request.getInstructions().size() != 1) { - logger.warn( + this.logger.warn( "Titan Embedding does not support batch embedding. Will make multiple API calls to embed(Document)"); } @@ -113,7 +108,7 @@ public class BedrockTitanEmbeddingModel extends AbstractEmbeddingModel { public int dimensions() { if (this.inputType == InputType.IMAGE) { if (this.embeddingDimensions.get() < 0) { - this.embeddingDimensions.set(dimensions(this, embeddingApi.getModelId(), + this.embeddingDimensions.set(dimensions(this, this.embeddingApi.getModelId(), // small base64 encoded image "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=")); } @@ -122,4 +117,10 @@ public class BedrockTitanEmbeddingModel extends AbstractEmbeddingModel { } + public enum InputType { + + TEXT, IMAGE + + } + } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java index 28757f3b7..61b82dbdf 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -39,23 +40,6 @@ public class BedrockTitanEmbeddingOptions implements EmbeddingOptions { return new Builder(); } - public static class Builder { - - private BedrockTitanEmbeddingOptions options = new BedrockTitanEmbeddingOptions(); - - public Builder withInputType(InputType inputType) { - Assert.notNull(inputType, "input type can not be null."); - - this.options.setInputType(inputType); - return this; - } - - public BedrockTitanEmbeddingOptions build() { - return this.options; - } - - } - public InputType getInputType() { return this.inputType; } @@ -76,4 +60,21 @@ public class BedrockTitanEmbeddingOptions implements EmbeddingOptions { return null; } + public static class Builder { + + private BedrockTitanEmbeddingOptions options = new BedrockTitanEmbeddingOptions(); + + public Builder withInputType(InputType inputType) { + Assert.notNull(inputType, "input type can not be null."); + + this.options.setInputType(inputType); + return this; + } + + public BedrockTitanEmbeddingOptions build() { + return this.options; + } + + } + } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java index 85a1f10c7..19e76729d 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan.api; import java.time.Duration; @@ -110,6 +111,55 @@ public class TitanChatBedrockApi extends super(modelId, credentialsProvider, region, objectMapper, timeout); } + @Override + public TitanChatResponse chatCompletion(TitanChatRequest request) { + return this.internalInvocation(request, TitanChatResponse.class); + } + + @Override + public Flux chatCompletionStream(TitanChatRequest request) { + return this.internalInvocationStream(request, TitanChatResponseChunk.class); + } + + /** + * Titan models version. + */ + public enum TitanChatModel implements ChatModelDescription { + + /** + * amazon.titan-text-lite-v1 + */ + TITAN_TEXT_LITE_V1("amazon.titan-text-lite-v1"), + + /** + * amazon.titan-text-express-v1 + */ + TITAN_TEXT_EXPRESS_V1("amazon.titan-text-express-v1"), + + /** + * amazon.titan-text-premier-v1:0 + */ + TITAN_TEXT_PREMIER_V1("amazon.titan-text-premier-v1:0"); + + private final String id; + + TitanChatModel(String value) { + this.id = value; + } + + /** + * @return The model id. + */ + public String id() { + return this.id; + } + + @Override + public String getName() { + return this.id; + } + } + /** * TitanChatRequest encapsulates the request parameters for the Titan chat model. * @@ -121,6 +171,15 @@ public class TitanChatBedrockApi extends @JsonProperty("inputText") String inputText, @JsonProperty("textGenerationConfig") TextGenerationConfig textGenerationConfig) { + /** + * Create a new TitanChatRequest builder. + * @param inputText The prompt to use for the chat. + * @return A new TitanChatRequest builder. + */ + public static Builder builder(String inputText) { + return new Builder(inputText); + } + /** * Titan request text generation configuration. * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html @@ -141,15 +200,6 @@ public class TitanChatBedrockApi extends @JsonProperty("stopSequences") List stopSequences) { } - /** - * Create a new TitanChatRequest builder. - * @param inputText The prompt to use for the chat. - * @return A new TitanChatRequest builder. - */ - public static Builder builder(String inputText) { - return new Builder(inputText); - } - public static class Builder { private final String inputText; private Double temperature; @@ -210,20 +260,6 @@ public class TitanChatBedrockApi extends @JsonProperty("inputTextTokenCount") Integer inputTextTokenCount, @JsonProperty("results") List results) { - /** - * Titan response result. - * - * @param tokenCount The number of tokens in the generated text. - * @param outputText The generated text. - * @param completionReason The reason the response finished being generated. - */ - @JsonInclude(Include.NON_NULL) - public record Result( - @JsonProperty("tokenCount") Integer tokenCount, - @JsonProperty("outputText") String outputText, - @JsonProperty("completionReason") CompletionReason completionReason) { - } - /** * The reason the response finished being generated. */ @@ -243,6 +279,20 @@ public class TitanChatBedrockApi extends */ CONTENT_FILTERED } + + /** + * Titan response result. + * + * @param tokenCount The number of tokens in the generated text. + * @param outputText The generated text. + * @param completionReason The reason the response finished being generated. + */ + @JsonInclude(Include.NON_NULL) + public record Result( + @JsonProperty("tokenCount") Integer tokenCount, + @JsonProperty("outputText") String outputText, + @JsonProperty("completionReason") CompletionReason completionReason) { + } } /** @@ -263,54 +313,5 @@ public class TitanChatBedrockApi extends @JsonProperty("completionReason") CompletionReason completionReason, @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { } - - /** - * Titan models version. - */ - public enum TitanChatModel implements ChatModelDescription { - - /** - * amazon.titan-text-lite-v1 - */ - TITAN_TEXT_LITE_V1("amazon.titan-text-lite-v1"), - - /** - * amazon.titan-text-express-v1 - */ - TITAN_TEXT_EXPRESS_V1("amazon.titan-text-express-v1"), - - /** - * amazon.titan-text-premier-v1:0 - */ - TITAN_TEXT_PREMIER_V1("amazon.titan-text-premier-v1:0"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; - } - - TitanChatModel(String value) { - this.id = value; - } - - @Override - public String getName() { - return this.id; - } - } - - @Override - public TitanChatResponse chatCompletion(TitanChatRequest request) { - return this.internalInvocation(request, TitanChatResponse.class); - } - - @Override - public Flux chatCompletionStream(TitanChatRequest request) { - return this.internalInvocationStream(request, TitanChatResponseChunk.class); - } } // @formatter:on diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java index 01968c81c..b94ccff9a 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan.api; import java.time.Duration; -import java.util.List; import java.util.Map; import com.fasterxml.jackson.annotation.JsonInclude; @@ -27,7 +27,6 @@ import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.api.AbstractBedrockApi; -import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingModel; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingResponse; import org.springframework.util.Assert; @@ -83,6 +82,42 @@ public class TitanEmbeddingBedrockApi extends super(modelId, credentialsProvider, region, objectMapper, timeout); } + @Override + public TitanEmbeddingResponse embedding(TitanEmbeddingRequest request) { + return this.internalInvocation(request, TitanEmbeddingResponse.class); + } + + /** + * Titan Embedding model ids. + */ + public enum TitanEmbeddingModel { + /** + * amazon.titan-embed-image-v1 + */ + TITAN_EMBED_IMAGE_V1("amazon.titan-embed-image-v1"), + /** + * amazon.titan-embed-text-v1 + */ + TITAN_EMBED_TEXT_V1("amazon.titan-embed-text-v1"), + /** + * amazon.titan-embed-text-v2 + */ + TITAN_EMBED_TEXT_V2("amazon.titan-embed-text-v2:0");; + + private final String id; + + TitanEmbeddingModel(String value) { + this.id = value; + } + + /** + * @return The model id. + */ + public String id() { + return this.id; + } + } + /** * Titan Embedding request parameters. * @@ -143,44 +178,8 @@ public class TitanEmbeddingBedrockApi extends @JsonProperty("inputTextTokenCount") Integer inputTextTokenCount, @JsonProperty("embeddingsByType") Map embeddingsByType, @JsonProperty("message") Object message) { - - - } - /** - * Titan Embedding model ids. - */ - public enum TitanEmbeddingModel { - /** - * amazon.titan-embed-image-v1 - */ - TITAN_EMBED_IMAGE_V1("amazon.titan-embed-image-v1"), - /** - * amazon.titan-embed-text-v1 - */ - TITAN_EMBED_TEXT_V1("amazon.titan-embed-text-v1"), - /** - * amazon.titan-embed-text-v2 - */ - TITAN_EMBED_TEXT_V2("amazon.titan-embed-text-v2:0");; - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; - } - - TitanEmbeddingModel(String value) { - this.id = value; - } - } - - @Override - public TitanEmbeddingResponse embedding(TitanEmbeddingRequest request) { - return this.internalInvocation(request, TitanEmbeddingResponse.class); } } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModelIT.java index e0893036b..0ccd002d7 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic; import java.time.Duration; @@ -22,7 +23,6 @@ import java.util.Map; import java.util.stream.Collectors; import com.fasterxml.jackson.databind.ObjectMapper; - import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; @@ -33,11 +33,11 @@ import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsPro import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -70,8 +70,8 @@ class BedrockAnthropicChatModelIT { @Test void multipleStreamAttempts() { - Flux joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); - Flux joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); + Flux joke1Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); + Flux joke2Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); String joke1 = joke1Stream.collectList() .block() @@ -98,12 +98,12 @@ class BedrockAnthropicChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -140,16 +140,13 @@ class BedrockAnthropicChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Disabled @Test void beanOutputConverterRecords() { @@ -165,7 +162,7 @@ class BedrockAnthropicChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConvert.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); @@ -186,7 +183,7 @@ class BedrockAnthropicChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -202,6 +199,10 @@ class BedrockAnthropicChatModelIT { assertThat(actorsFilms.movies()).hasSize(5); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicCreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicCreateRequestTests.java index 3cf8b344b..37eaa0b57 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicCreateRequestTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicCreateRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic; import java.time.Duration; @@ -38,7 +39,7 @@ public class BedrockAnthropicCreateRequestTests { @Test public void createRequestWithChatOptions() { - var client = new BedrockAnthropicChatModel(anthropicChatApi, + var client = new BedrockAnthropicChatModel(this.anthropicChatApi, AnthropicChatOptions.builder() .withTemperature(66.6) .withTopK(66) diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java index 8f0efe45a..3638104cb 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic.api; import java.time.Duration; @@ -28,11 +29,13 @@ import reactor.core.publisher.Flux; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; +import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatModel; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse; -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatModel; -import static org.assertj.core.api.Assertions.assertThat;; +import static org.assertj.core.api.Assertions.assertThat; + +; /** * @author Christian Tzolov @@ -57,7 +60,7 @@ public class AnthropicChatBedrockApiIT { .withTopK(10) .build(); - AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); + AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); System.out.println(response.completion()); assertThat(response).isNotNull(); @@ -67,7 +70,7 @@ public class AnthropicChatBedrockApiIT { assertThat(response.stop()).isEqualTo("\n\nHuman:"); assertThat(response.amazonBedrockInvocationMetrics()).isNull(); - logger.info("" + response); + this.logger.info("" + response); } @Test @@ -81,7 +84,7 @@ public class AnthropicChatBedrockApiIT { .withStopSequences(List.of("\n\nHuman:")) .build(); - Flux responseStream = anthropicChatApi.chatCompletionStream(request); + Flux responseStream = this.anthropicChatApi.chatCompletionStream(request); List responses = responseStream.collectList().block(); assertThat(responses).isNotNull(); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java index 22644b42b..d2d906035 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic3; import java.io.IOException; @@ -32,18 +33,18 @@ import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsPro import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.model.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.Media; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -72,8 +73,8 @@ class BedrockAnthropic3ChatModelIT { @Test void multipleStreamAttempts() { - Flux joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); - Flux joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); + Flux joke1Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); + Flux joke2Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); String joke1 = joke1Stream.collectList() .block() @@ -100,12 +101,12 @@ class BedrockAnthropic3ChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -142,16 +143,13 @@ class BedrockAnthropic3ChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -166,7 +164,7 @@ class BedrockAnthropic3ChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); @@ -187,7 +185,7 @@ class BedrockAnthropic3ChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -211,7 +209,7 @@ class BedrockAnthropic3ChatModelIT { var userMessage = new UserMessage("Explain what do you see o this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel.call(new Prompt(List.of(userMessage))); + var response = this.chatModel.call(new Prompt(List.of(userMessage))); logger.info(response.getResult().getOutput().getContent()); assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "basket"); @@ -222,12 +220,16 @@ class BedrockAnthropic3ChatModelIT { Anthropic3ChatOptions chatOptions = new Anthropic3ChatOptions(); chatOptions.setStopSequences(List.of("Hello")); - var response = chatModel.call(new Prompt("hi", chatOptions)); + var response = this.chatModel.call(new Prompt("hi", chatOptions)); assertThat(response).isNotNull(); assertThat(response.getResults()).isEmpty(); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3CreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3CreateRequestTests.java index 75551ca1c..bb76ae3d8 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3CreateRequestTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3CreateRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.bedrock.anthropic3; -import org.junit.jupiter.api.Test; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; -import org.springframework.ai.chat.prompt.Prompt; -import software.amazon.awssdk.regions.Region; +package org.springframework.ai.bedrock.anthropic3; import java.time.Duration; import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.regions.Region; + +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; +import org.springframework.ai.chat.prompt.Prompt; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -37,7 +39,7 @@ public class BedrockAnthropic3CreateRequestTests { @Test public void createRequestWithChatOptions() { - var client = new BedrockAnthropic3ChatModel(anthropicChatApi, + var client = new BedrockAnthropic3ChatModel(this.anthropicChatApi, Anthropic3ChatOptions.builder() .withTemperature(66.6) .withTopK(66) diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java index 48b89af37..55b054889 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,27 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.anthropic3.api; +import java.time.Duration; +import java.util.List; +import java.util.stream.Collectors; + import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatRequest; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role; import reactor.core.publisher.Flux; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; -import java.time.Duration; -import java.util.List; -import java.util.stream.Collectors; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatRequest; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.DEFAULT_ANTHROPIC_VERSION; @@ -63,9 +65,9 @@ public class Anthropic3ChatBedrockApiIT { .withAnthropicVersion(DEFAULT_ANTHROPIC_VERSION) .build(); - AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); + AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); - logger.info("" + response.content()); + this.logger.info("" + response.content()); assertThat(response).isNotNull(); assertThat(response.content().get(0).text()).isNotEmpty(); @@ -75,7 +77,7 @@ public class Anthropic3ChatBedrockApiIT { assertThat(response.usage().inputTokens()).isGreaterThan(10); assertThat(response.usage().outputTokens()).isGreaterThan(100); - logger.info("" + response); + this.logger.info("" + response); } @Test @@ -103,9 +105,9 @@ public class Anthropic3ChatBedrockApiIT { .withAnthropicVersion(DEFAULT_ANTHROPIC_VERSION) .build(); - AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); + AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); - logger.info("" + response.content()); + this.logger.info("" + response.content()); assertThat(response).isNotNull(); assertThat(response.content().get(0).text()).isNotEmpty(); assertThat(response.content().get(0).text()).contains("Blackbeard"); @@ -114,7 +116,7 @@ public class Anthropic3ChatBedrockApiIT { assertThat(response.usage().inputTokens()).isGreaterThan(30); assertThat(response.usage().outputTokens()).isGreaterThan(200); - logger.info("" + response); + this.logger.info("" + response); } @Test @@ -129,7 +131,7 @@ public class Anthropic3ChatBedrockApiIT { .withAnthropicVersion(DEFAULT_ANTHROPIC_VERSION) .build(); - Flux responseStream = anthropicChatApi + Flux responseStream = this.anthropicChatApi .chatCompletionStream(request); List responses = responseStream.collectList().block(); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java index 92060ef5b..f3f33bbfc 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.aot; +import java.util.Arrays; +import java.util.List; +import java.util.Set; + import org.junit.jupiter.api.Test; + import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi; @@ -26,10 +32,6 @@ import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; -import java.util.List; -import java.util.Set; -import java.util.Arrays; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatCreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatCreateRequestTests.java index b6c0027da..a1e3d7a3f 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatCreateRequestTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatCreateRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere; import java.time.Duration; @@ -45,7 +46,7 @@ public class BedrockCohereChatCreateRequestTests { @Test public void createRequestWithChatOptions() { - var client = new BedrockCohereChatModel(chatApi, + var client = new BedrockCohereChatModel(this.chatApi, BedrockCohereChatOptions.builder() .withTemperature(66.6) .withTopK(66) diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java index 5da9f8670..340b541b0 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere; import java.time.Duration; @@ -30,11 +31,11 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -65,8 +66,8 @@ class BedrockCohereChatModelIT { @Test void multipleStreamAttempts() { - Flux joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); - Flux joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); + Flux joke1Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); + Flux joke2Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); String joke1 = joke1Stream.collectList() .block() @@ -95,10 +96,10 @@ class BedrockCohereChatModelIT { String name = "Bob"; String voice = "pirate"; UserMessage userMessage = new UserMessage(request); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -134,16 +135,13 @@ class BedrockCohereChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -157,7 +155,7 @@ class BedrockCohereChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); @@ -178,7 +176,7 @@ class BedrockCohereChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -194,6 +192,10 @@ class BedrockCohereChatModelIT { assertThat(actorsFilms.movies()).hasSize(5); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java index 194b657ed..03d3f0145 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere; import java.time.Duration; @@ -46,17 +47,17 @@ class BedrockCohereEmbeddingModelIT { @Test void singleEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); - assertThat(embeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } @Test void batchEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); assertThat(embeddingResponse.getResults()).hasSize(2); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); @@ -64,13 +65,13 @@ class BedrockCohereEmbeddingModelIT { assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); - assertThat(embeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } @Test void embeddingWthOptions() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel .call(new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"), BedrockCohereEmbeddingOptions.builder().withInputType(InputType.SEARCH_DOCUMENT).build())); assertThat(embeddingResponse.getResults()).hasSize(2); @@ -79,7 +80,7 @@ class BedrockCohereEmbeddingModelIT { assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); - assertThat(embeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } @SpringBootConfiguration diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java index 287eec21f..27c11af67 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere.api; import java.time.Duration; @@ -32,7 +33,9 @@ import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChat import org.springframework.ai.model.ModelOptionsUtils; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy;; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +; /** * @author Christian Tzolov @@ -86,7 +89,7 @@ public class CohereChatBedrockApiIT { .withTruncate(Truncate.NONE) .build(); - CohereChatResponse response = cohereChatApi.chatCompletion(request); + CohereChatResponse response = this.cohereChatApi.chatCompletion(request); assertThat(response).isNotNull(); assertThat(response.prompt()).isEqualTo(request.prompt()); @@ -111,7 +114,7 @@ public class CohereChatBedrockApiIT { .withTruncate(Truncate.NONE) .build(); - Flux responseStream = cohereChatApi.chatCompletionStream(request); + Flux responseStream = this.cohereChatApi.chatCompletionStream(request); List responses = responseStream.collectList().block(); assertThat(responses).isNotNull(); @@ -132,7 +135,7 @@ public class CohereChatBedrockApiIT { .withStream(true) .build(); - assertThatThrownBy(() -> cohereChatApi.chatCompletion(streamRequest)) + assertThatThrownBy(() -> this.cohereChatApi.chatCompletion(streamRequest)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("The request must be configured to return the complete response!"); @@ -141,7 +144,7 @@ public class CohereChatBedrockApiIT { .withStream(false) .build(); - assertThatThrownBy(() -> cohereChatApi.chatCompletionStream(notStreamRequest)) + assertThatThrownBy(() -> this.cohereChatApi.chatCompletionStream(notStreamRequest)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("The request must be configured to stream the response!"); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApiIT.java index 83afec90d..e8154344b 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.cohere.api; import java.time.Duration; @@ -49,7 +50,7 @@ public class CohereEmbeddingBedrockApiIT { List.of("I like to eat apples", "I like to eat oranges"), CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT, CohereEmbeddingRequest.Truncate.NONE); - CohereEmbeddingResponse response = api.embedding(request); + CohereEmbeddingResponse response = this.api.embedding(request); assertThat(response).isNotNull(); assertThat(response.texts()).isEqualTo(request.texts()); @@ -64,7 +65,7 @@ public class CohereEmbeddingBedrockApiIT { List.of("I like to eat apples", "I like to eat oranges"), CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT, CohereEmbeddingRequest.Truncate.START); - CohereEmbeddingResponse response = api.embedding(request); + CohereEmbeddingResponse response = this.api.embedding(request); assertThat(response).isNotNull(); assertThat(response.texts()).isEqualTo(request.texts()); @@ -74,7 +75,7 @@ public class CohereEmbeddingBedrockApiIT { request = new CohereEmbeddingRequest(List.of("I like to eat apples", "I like to eat oranges"), CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT, CohereEmbeddingRequest.Truncate.END); - response = api.embedding(request); + response = this.api.embedding(request); assertThat(response).isNotNull(); assertThat(response.texts()).isEqualTo(request.texts()); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java index c7a641977..c0919cd03 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -29,10 +29,10 @@ import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsPro import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -61,12 +61,12 @@ class BedrockAi21Jurassic2ChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -83,7 +83,7 @@ class BedrockAi21Jurassic2ChatModelIT { UserMessage userMessage = new UserMessage("Can you express happiness using an emoji like 😄 ?"); Prompt prompt = new Prompt(List.of(userMessage), options); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).matches(content -> content.contains("😄")); } @@ -98,12 +98,12 @@ class BedrockAi21Jurassic2ChatModelIT { .build(); UserMessage userMessage = new UserMessage("Can you express happiness using an emoji like 😄?"); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage), options); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).doesNotContain("😄"); } @@ -120,7 +120,7 @@ class BedrockAi21Jurassic2ChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -131,12 +131,12 @@ class BedrockAi21Jurassic2ChatModelIT { @Test void simpleChatResponse() { UserMessage userMessage = new UserMessage("Tell me a joke about AI."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("AI"); } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApiIT.java index f3dedde9e..aa16faa71 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.jurassic2.api; import java.time.Duration; @@ -20,7 +21,6 @@ import java.util.stream.Collectors; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -50,7 +50,7 @@ public class Ai21Jurassic2ChatBedrockApiIT { new Ai21Jurassic2ChatRequest.FloatScalePenalty(0.5f, true, true, true, true, true), new Ai21Jurassic2ChatRequest.IntegerScalePenalty(1, true, true, true, true, true)); - Ai21Jurassic2ChatResponse response = api.chatCompletion(request); + Ai21Jurassic2ChatResponse response = this.api.chatCompletion(request); assertThat(response).isNotNull(); assertThat(response.completions()).isNotEmpty(); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java index c9239875e..168250f9d 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.llama; import java.time.Duration; @@ -30,11 +31,11 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -65,8 +66,8 @@ class BedrockLlamaChatModelIT { @Test void multipleStreamAttempts() { - Flux joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a Toy joke?"))); - Flux joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); + Flux joke2Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a Toy joke?"))); + Flux joke1Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); String joke1 = joke1Stream.collectList() .block() @@ -93,12 +94,12 @@ class BedrockLlamaChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -134,16 +135,13 @@ class BedrockLlamaChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -158,7 +156,7 @@ class BedrockLlamaChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); @@ -179,7 +177,7 @@ class BedrockLlamaChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -195,6 +193,10 @@ class BedrockLlamaChatModelIT { assertThat(actorsFilms.movies()).hasSize(5); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java index 3add11d14..48c81556b 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.llama; +import java.time.Duration; + import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -26,8 +28,6 @@ import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; import org.springframework.ai.chat.prompt.Prompt; -import java.time.Duration; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -45,7 +45,7 @@ public class BedrockLlamaCreateRequestTests { @Test public void createRequestWithChatOptions() { - var client = new BedrockLlamaChatModel(api, + var client = new BedrockLlamaChatModel(this.api, BedrockLlamaChatOptions.builder().withTemperature(66.6).withMaxGenLen(666).withTopP(0.66).build()); var request = client.createRequest(new Prompt("Test message content")); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java index 48844670c..664e02194 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,23 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.llama.api; import java.time.Duration; import java.util.List; +import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse; - -import com.fasterxml.jackson.databind.ObjectMapper; - import reactor.core.publisher.Flux; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest; +import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -53,7 +53,7 @@ public class LlamaChatBedrockApiIT { .withMaxGenLen(20) .build(); - LlamaChatResponse response = llamaChatApi.chatCompletion(request); + LlamaChatResponse response = this.llamaChatApi.chatCompletion(request); System.out.println(response.generation()); assertThat(response).isNotNull(); @@ -68,7 +68,7 @@ public class LlamaChatBedrockApiIT { public void chatCompletionStream() { LlamaChatRequest request = new LlamaChatRequest("Hello, my name is", 0.9, 0.9, 20); - Flux responseStream = llamaChatApi.chatCompletionStream(request); + Flux responseStream = this.llamaChatApi.chatCompletionStream(request); List responses = responseStream.collectList().block(); assertThat(responses).isNotNull(); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelCreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelCreateRequestTests.java index 90705ecc2..81c62d70b 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelCreateRequestTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelCreateRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan; import java.time.Duration; @@ -41,7 +42,7 @@ public class BedrockTitanChatModelCreateRequestTests { @Test public void createRequestWithChatOptions() { - var model = new BedrockTitanChatModel(api, + var model = new BedrockTitanChatModel(this.api, BedrockTitanChatOptions.builder() .withTemperature(66.6) .withTopP(0.66) diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java index 45bc6fed1..b96991c38 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java @@ -31,11 +31,11 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModelIT.java index 7670e8db9..ae4cdb6e3 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan; import java.io.IOException; @@ -20,9 +21,9 @@ import java.time.Duration; import java.util.Base64; import java.util.List; +import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -37,8 +38,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import com.fasterxml.jackson.databind.ObjectMapper; - import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest @@ -51,12 +50,12 @@ class BedrockTitanEmbeddingModelIT { @Test void singleEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(List.of("Hello World"), + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of("Hello World"), BedrockTitanEmbeddingOptions.builder().withInputType(InputType.TEXT).build())); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); - assertThat(embeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } @Test @@ -65,12 +64,12 @@ class BedrockTitanEmbeddingModelIT { byte[] image = new DefaultResourceLoader().getResource("classpath:/spring_framework.png") .getContentAsByteArray(); - EmbeddingResponse embeddingResponse = embeddingModel + EmbeddingResponse embeddingResponse = this.embeddingModel .call(new EmbeddingRequest(List.of(Base64.getEncoder().encodeToString(image)), BedrockTitanEmbeddingOptions.builder().withInputType(InputType.IMAGE).build())); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); - assertThat(embeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } @SpringBootConfiguration diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApiIT.java index 453e84490..094f182bb 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan.api; import java.time.Duration; @@ -53,14 +54,14 @@ public class TitanChatBedrockApiIT { @Test public void chatCompletion() { - TitanChatResponse response = titanBedrockApi.chatCompletion(titanChatRequest); + TitanChatResponse response = this.titanBedrockApi.chatCompletion(this.titanChatRequest); assertThat(response.results()).hasSize(1); assertThat(response.results().get(0).outputText()).contains("Blackbeard"); } @Test public void chatCompletionStream() { - Flux response = titanBedrockApi.chatCompletionStream(titanChatRequest); + Flux response = this.titanBedrockApi.chatCompletionStream(this.titanChatRequest); List results = response.collectList().block(); assertThat(results.stream() diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApiIT.java index 4f7813b2d..f27a56bf6 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan.api; import java.io.IOException; import java.time.Duration; import java.util.Base64; +import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -30,8 +31,6 @@ import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEm import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingResponse; import org.springframework.core.io.DefaultResourceLoader; -import com.fasterxml.jackson.databind.ObjectMapper; - import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/models/spring-ai-huggingface/pom.xml b/models/spring-ai-huggingface/pom.xml index 4dac3c3ef..9c74fb326 100644 --- a/models/spring-ai-huggingface/pom.xml +++ b/models/spring-ai-huggingface/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java index 8a8c8d92b..affd1a329 100644 --- a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.huggingface; import java.util.ArrayList; @@ -25,15 +26,15 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.huggingface.api.TextGenerationInferenceApi; import org.springframework.ai.huggingface.invoker.ApiClient; import org.springframework.ai.huggingface.model.AllOfGenerateResponseDetails; import org.springframework.ai.huggingface.model.GenerateParameters; import org.springframework.ai.huggingface.model.GenerateRequest; import org.springframework.ai.huggingface.model.GenerateResponse; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; -import org.springframework.ai.chat.prompt.Prompt; /** * An implementation of {@link ChatModel} that interfaces with HuggingFace Inference @@ -92,14 +93,15 @@ public class HuggingfaceChatModel implements ChatModel { generateRequest.setInputs(prompt.getContents()); GenerateParameters generateParameters = new GenerateParameters(); // TODO - need to expose API to set parameters per call. - generateParameters.setMaxNewTokens(maxNewTokens); + generateParameters.setMaxNewTokens(this.maxNewTokens); generateRequest.setParameters(generateParameters); GenerateResponse generateResponse = this.textGenApi.generate(generateRequest); String generatedText = generateResponse.getGeneratedText(); List generations = new ArrayList<>(); AllOfGenerateResponseDetails allOfGenerateResponseDetails = generateResponse.getDetails(); - Map detailsMap = objectMapper.convertValue(allOfGenerateResponseDetails, + Map detailsMap = this.objectMapper.convertValue(allOfGenerateResponseDetails, new TypeReference>() { + }); Generation generation = new Generation(generatedText, detailsMap); generations.add(generation); @@ -111,7 +113,7 @@ public class HuggingfaceChatModel implements ChatModel { * @return The maximum number of new tokens. */ public int getMaxNewTokens() { - return maxNewTokens; + return this.maxNewTokens; } /** diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java index 8e2a90d9d..5f933a09c 100644 --- a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.huggingface; import org.springframework.boot.SpringBootConfiguration; diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java index 48f32b6ba..9106ae98d 100644 --- a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.huggingface.client; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.huggingface.client; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.huggingface.HuggingfaceChatModel; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; +import static org.assertj.core.api.Assertions.assertThat; + @SpringBootTest @EnabledIfEnvironmentVariable(named = "HUGGINGFACE_API_KEY", matches = ".+") @EnabledIfEnvironmentVariable(named = "HUGGINGFACE_CHAT_URL", matches = ".+") @@ -44,7 +46,7 @@ public class ClientIT { [/INST] """; Prompt prompt = new Prompt(mistral7bInstruct); - ChatResponse chatResponse = huggingfaceChatModel.call(prompt); + ChatResponse chatResponse = this.huggingfaceChatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); String expectedResponse = """ { diff --git a/models/spring-ai-minimax/pom.xml b/models/spring-ai-minimax/pom.xml index ee8ea6c07..85824eb49 100644 --- a/models/spring-ai-minimax/pom.xml +++ b/models/spring-ai-minimax/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index e8aca8bdb..7d677db41 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -61,15 +72,6 @@ import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; import static org.springframework.ai.minimax.api.MiniMaxApiConstants.TOOL_CALL_FUNCTION_TYPE; @@ -89,16 +91,16 @@ public class MiniMaxChatModel extends AbstractToolCallSupport implements ChatMod private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); - /** - * The default options used for the chat completion requests. - */ - private final MiniMaxChatOptions defaultOptions; - /** * The retry template used to retry the MiniMax API calls. */ public final RetryTemplate retryTemplate; + /** + * The default options used for the chat completion requests. + */ + private final MiniMaxChatOptions defaultOptions; + /** * Low-level access to the MiniMax API. */ @@ -174,6 +176,40 @@ public class MiniMaxChatModel extends AbstractToolCallSupport implements ChatMod this.observationRegistry = observationRegistry; } + private static Generation buildGeneration(Choice choice, Map metadata) { + List toolCalls = choice.message().toolCalls() == null ? List.of() + : choice.message() + .toolCalls() + .stream() + // the MiniMax's stream function calls response are really odd + // occasionally, tool call might get split. + // for example, id empty means the previous tool call is not finished, + // the toolCalls: + // [{id:'1',function:{name:'a'}},{id:'',function:{arguments:'[1]'}}] + // these need to be merged into [{id:'1', name:'a', arguments:'[1]'}] + // it worked before, maybe the model provider made some adjustments + .reduce(new ArrayList<>(), (acc, current) -> { + if (!acc.isEmpty() && current.id().isEmpty()) { + AssistantMessage.ToolCall prev = acc.get(acc.size() - 1); + acc.set(acc.size() - 1, new AssistantMessage.ToolCall(prev.id(), prev.type(), prev.name(), + current.function().arguments())); + } + else { + AssistantMessage.ToolCall currentToolCall = new AssistantMessage.ToolCall(current.id(), + current.type(), current.function().name(), current.function().arguments()); + acc.add(currentToolCall); + } + return acc; + }, (acc1, acc2) -> { + acc1.addAll(acc2); + return acc1; + }); + var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); + String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); + var generationMetadata = ChatGenerationMetadata.from(finishReason, null); + return new Generation(assistantMessage, generationMetadata); + } + @Override public ChatResponse call(Prompt prompt) { ChatCompletionRequest request = createRequest(prompt, false); @@ -376,40 +412,6 @@ public class MiniMaxChatModel extends AbstractToolCallSupport implements ChatMod return new Generation(assistantMessage, generationMetadata); } - private static Generation buildGeneration(Choice choice, Map metadata) { - List toolCalls = choice.message().toolCalls() == null ? List.of() - : choice.message() - .toolCalls() - .stream() - // the MiniMax's stream function calls response are really odd - // occasionally, tool call might get split. - // for example, id empty means the previous tool call is not finished, - // the toolCalls: - // [{id:'1',function:{name:'a'}},{id:'',function:{arguments:'[1]'}}] - // these need to be merged into [{id:'1', name:'a', arguments:'[1]'}] - // it worked before, maybe the model provider made some adjustments - .reduce(new ArrayList<>(), (acc, current) -> { - if (!acc.isEmpty() && current.id().isEmpty()) { - AssistantMessage.ToolCall prev = acc.get(acc.size() - 1); - acc.set(acc.size() - 1, new AssistantMessage.ToolCall(prev.id(), prev.type(), prev.name(), - current.function().arguments())); - } - else { - AssistantMessage.ToolCall currentToolCall = new AssistantMessage.ToolCall(current.id(), - current.type(), current.function().name(), current.function().arguments()); - acc.add(currentToolCall); - } - return acc; - }, (acc1, acc2) -> { - acc1.addAll(acc2); - return acc1; - }); - var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); - String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); - var generationMetadata = ChatGenerationMetadata.from(finishReason, null); - return new Generation(assistantMessage, generationMetadata); - } - /** * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. * @param chunk the ChatCompletionChunk to convert diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java index 7762b4a4f..0e10ca20e 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.minimax; -import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.minimax.api.MiniMaxApi; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.boot.context.properties.NestedConfigurationProperty; -import org.springframework.util.Assert; +package org.springframework.ai.minimax; import java.util.ArrayList; import java.util.HashSet; @@ -32,6 +22,18 @@ import java.util.List; import java.util.Map; import java.util.Set; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.minimax.api.MiniMaxApi; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; + /** * MiniMaxChatOptions represents the options for performing chat completion using the * MiniMax API. It provides methods to set and retrieve various options like model, @@ -157,6 +159,356 @@ public class MiniMaxChatOptions implements FunctionCallingOptions, ChatOptions { return new Builder(); } + public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { + return builder().withModel(fromOptions.getModel()) + .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withN(fromOptions.getN()) + .withPresencePenalty(fromOptions.getPresencePenalty()) + .withResponseFormat(fromOptions.getResponseFormat()) + .withSeed(fromOptions.getSeed()) + .withStop(fromOptions.getStop()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withMaskSensitiveInfo(fromOptions.getMaskSensitiveInfo()) + .withTools(fromOptions.getTools()) + .withToolChoice(fromOptions.getToolChoice()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withToolContext(fromOptions.getToolContext()) + .build(); + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Integer getN() { + return this.n; + } + + public void setN(Integer n) { + this.n = n; + } + + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public MiniMaxApi.ChatCompletionRequest.ResponseFormat getResponseFormat() { + return this.responseFormat; + } + + public void setResponseFormat(MiniMaxApi.ChatCompletionRequest.ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + public Integer getSeed() { + return this.seed; + } + + public void setSeed(Integer seed) { + this.seed = seed; + } + + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + + public List getStop() { + return this.stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + @Override + public Double getTemperature() { + return this.temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + @Override + public Double getTopP() { + return this.topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + public Boolean getMaskSensitiveInfo() { + return this.maskSensitiveInfo; + } + + public void setMaskSensitiveInfo(Boolean maskSensitiveInfo) { + this.maskSensitiveInfo = maskSensitiveInfo; + } + + public List getTools() { + return this.tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + public String getToolChoice() { + return this.toolChoice; + } + + public void setToolChoice(String toolChoice) { + this.toolChoice = toolChoice; + } + + @Override + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + @Override + public void setFunctionCallbacks(List functionCallbacks) { + this.functionCallbacks = functionCallbacks; + } + + @Override + public Set getFunctions() { + return this.functions; + } + + public void setFunctions(Set functionNames) { + this.functions = functionNames; + } + + @Override + @JsonIgnore + public Integer getTopK() { + return null; + } + + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + + @Override + public Map getToolContext() { + return this.toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); + result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); + result = prime * result + ((this.n == null) ? 0 : this.n.hashCode()); + result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); + result = prime * result + ((this.responseFormat == null) ? 0 : this.responseFormat.hashCode()); + result = prime * result + ((this.seed == null) ? 0 : this.seed.hashCode()); + result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); + result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); + result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); + result = prime * result + ((this.maskSensitiveInfo == null) ? 0 : this.maskSensitiveInfo.hashCode()); + result = prime * result + ((this.tools == null) ? 0 : this.tools.hashCode()); + result = prime * result + ((this.toolChoice == null) ? 0 : this.toolChoice.hashCode()); + result = prime * result + ((this.proxyToolCalls == null) ? 0 : this.proxyToolCalls.hashCode()); + result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + MiniMaxChatOptions other = (MiniMaxChatOptions) obj; + if (this.model == null) { + if (other.model != null) { + return false; + } + } + else if (!this.model.equals(other.model)) { + return false; + } + if (this.frequencyPenalty == null) { + if (other.frequencyPenalty != null) { + return false; + } + } + else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) { + return false; + } + if (this.maxTokens == null) { + if (other.maxTokens != null) { + return false; + } + } + else if (!this.maxTokens.equals(other.maxTokens)) { + return false; + } + if (this.n == null) { + if (other.n != null) { + return false; + } + } + else if (!this.n.equals(other.n)) { + return false; + } + if (this.presencePenalty == null) { + if (other.presencePenalty != null) { + return false; + } + } + else if (!this.presencePenalty.equals(other.presencePenalty)) { + return false; + } + if (this.responseFormat == null) { + if (other.responseFormat != null) { + return false; + } + } + else if (!this.responseFormat.equals(other.responseFormat)) { + return false; + } + if (this.seed == null) { + if (other.seed != null) { + return false; + } + } + else if (!this.seed.equals(other.seed)) { + return false; + } + if (this.stop == null) { + if (other.stop != null) { + return false; + } + } + else if (!this.stop.equals(other.stop)) { + return false; + } + if (this.temperature == null) { + if (other.temperature != null) { + return false; + } + } + else if (!this.temperature.equals(other.temperature)) { + return false; + } + if (this.topP == null) { + if (other.topP != null) { + return false; + } + } + else if (!this.topP.equals(other.topP)) { + return false; + } + if (this.maskSensitiveInfo == null) { + if (other.maskSensitiveInfo != null) { + return false; + } + } + else if (!this.maskSensitiveInfo.equals(other.maskSensitiveInfo)) { + return false; + } + if (this.tools == null) { + if (other.tools != null) { + return false; + } + } + else if (!this.tools.equals(other.tools)) { + return false; + } + if (this.toolChoice == null) { + if (other.toolChoice != null) { + return false; + } + } + else if (!this.toolChoice.equals(other.toolChoice)) { + return false; + } + if (this.proxyToolCalls == null) { + if (other.proxyToolCalls != null) { + return false; + } + } + else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) { + return false; + } + + if (this.toolContext == null) { + if (other.toolContext != null) { + return false; + } + } + else if (!this.toolContext.equals(other.toolContext)) { + return false; + } + + return true; + } + + @Override + public MiniMaxChatOptions copy() { + return fromOptions(this); + } + public static class Builder { protected MiniMaxChatOptions options; @@ -272,321 +624,4 @@ public class MiniMaxChatOptions implements FunctionCallingOptions, ChatOptions { } - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - @Override - public Double getFrequencyPenalty() { - return this.frequencyPenalty; - } - - public void setFrequencyPenalty(Double frequencyPenalty) { - this.frequencyPenalty = frequencyPenalty; - } - - @Override - public Integer getMaxTokens() { - return this.maxTokens; - } - - public void setMaxTokens(Integer maxTokens) { - this.maxTokens = maxTokens; - } - - public Integer getN() { - return this.n; - } - - public void setN(Integer n) { - this.n = n; - } - - @Override - public Double getPresencePenalty() { - return this.presencePenalty; - } - - public void setPresencePenalty(Double presencePenalty) { - this.presencePenalty = presencePenalty; - } - - public MiniMaxApi.ChatCompletionRequest.ResponseFormat getResponseFormat() { - return this.responseFormat; - } - - public void setResponseFormat(MiniMaxApi.ChatCompletionRequest.ResponseFormat responseFormat) { - this.responseFormat = responseFormat; - } - - public Integer getSeed() { - return this.seed; - } - - public void setSeed(Integer seed) { - this.seed = seed; - } - - @Override - @JsonIgnore - public List getStopSequences() { - return getStop(); - } - - @JsonIgnore - public void setStopSequences(List stopSequences) { - setStop(stopSequences); - } - - public List getStop() { - return this.stop; - } - - public void setStop(List stop) { - this.stop = stop; - } - - @Override - public Double getTemperature() { - return this.temperature; - } - - public void setTemperature(Double temperature) { - this.temperature = temperature; - } - - @Override - public Double getTopP() { - return this.topP; - } - - public void setTopP(Double topP) { - this.topP = topP; - } - - public Boolean getMaskSensitiveInfo() { - return maskSensitiveInfo; - } - - public void setMaskSensitiveInfo(Boolean maskSensitiveInfo) { - this.maskSensitiveInfo = maskSensitiveInfo; - } - - public List getTools() { - return this.tools; - } - - public void setTools(List tools) { - this.tools = tools; - } - - public String getToolChoice() { - return this.toolChoice; - } - - public void setToolChoice(String toolChoice) { - this.toolChoice = toolChoice; - } - - @Override - public List getFunctionCallbacks() { - return this.functionCallbacks; - } - - @Override - public void setFunctionCallbacks(List functionCallbacks) { - this.functionCallbacks = functionCallbacks; - } - - @Override - public Set getFunctions() { - return functions; - } - - public void setFunctions(Set functionNames) { - this.functions = functionNames; - } - - @Override - @JsonIgnore - public Integer getTopK() { - return null; - } - - @Override - public Boolean getProxyToolCalls() { - return this.proxyToolCalls; - } - - public void setProxyToolCalls(Boolean proxyToolCalls) { - this.proxyToolCalls = proxyToolCalls; - } - - @Override - public Map getToolContext() { - return this.toolContext; - } - - @Override - public void setToolContext(Map toolContext) { - this.toolContext = toolContext; - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode()); - result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); - result = prime * result + ((n == null) ? 0 : n.hashCode()); - result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode()); - result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); - result = prime * result + ((seed == null) ? 0 : seed.hashCode()); - result = prime * result + ((stop == null) ? 0 : stop.hashCode()); - result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); - result = prime * result + ((topP == null) ? 0 : topP.hashCode()); - result = prime * result + ((maskSensitiveInfo == null) ? 0 : maskSensitiveInfo.hashCode()); - result = prime * result + ((tools == null) ? 0 : tools.hashCode()); - result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); - result = prime * result + ((proxyToolCalls == null) ? 0 : proxyToolCalls.hashCode()); - result = prime * result + ((toolContext == null) ? 0 : toolContext.hashCode()); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - MiniMaxChatOptions other = (MiniMaxChatOptions) obj; - if (this.model == null) { - if (other.model != null) - return false; - } - else if (!model.equals(other.model)) - return false; - if (this.frequencyPenalty == null) { - if (other.frequencyPenalty != null) - return false; - } - else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) - return false; - if (this.maxTokens == null) { - if (other.maxTokens != null) - return false; - } - else if (!this.maxTokens.equals(other.maxTokens)) - return false; - if (this.n == null) { - if (other.n != null) - return false; - } - else if (!this.n.equals(other.n)) - return false; - if (this.presencePenalty == null) { - if (other.presencePenalty != null) - return false; - } - else if (!this.presencePenalty.equals(other.presencePenalty)) - return false; - if (this.responseFormat == null) { - if (other.responseFormat != null) - return false; - } - else if (!this.responseFormat.equals(other.responseFormat)) - return false; - if (this.seed == null) { - if (other.seed != null) - return false; - } - else if (!this.seed.equals(other.seed)) - return false; - if (this.stop == null) { - if (other.stop != null) - return false; - } - else if (!stop.equals(other.stop)) - return false; - if (this.temperature == null) { - if (other.temperature != null) - return false; - } - else if (!this.temperature.equals(other.temperature)) - return false; - if (this.topP == null) { - if (other.topP != null) - return false; - } - else if (!topP.equals(other.topP)) - return false; - if (this.maskSensitiveInfo == null) { - if (other.maskSensitiveInfo != null) - return false; - } - else if (!maskSensitiveInfo.equals(other.maskSensitiveInfo)) - return false; - if (this.tools == null) { - if (other.tools != null) - return false; - } - else if (!tools.equals(other.tools)) - return false; - if (this.toolChoice == null) { - if (other.toolChoice != null) - return false; - } - else if (!toolChoice.equals(other.toolChoice)) - return false; - if (this.proxyToolCalls == null) { - if (other.proxyToolCalls != null) - return false; - } - else if (!proxyToolCalls.equals(other.proxyToolCalls)) - return false; - - if (this.toolContext == null) { - if (other.toolContext != null) - return false; - } - else if (!toolContext.equals(other.toolContext)) - return false; - - return true; - } - - @Override - public MiniMaxChatOptions copy() { - return fromOptions(this); - } - - public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { - return builder().withModel(fromOptions.getModel()) - .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withN(fromOptions.getN()) - .withPresencePenalty(fromOptions.getPresencePenalty()) - .withResponseFormat(fromOptions.getResponseFormat()) - .withSeed(fromOptions.getSeed()) - .withStop(fromOptions.getStop()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withMaskSensitiveInfo(fromOptions.getMaskSensitiveInfo()) - .withTools(fromOptions.getTools()) - .withToolChoice(fromOptions.getToolChoice()) - .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) - .withFunctions(fromOptions.getFunctions()) - .withProxyToolCalls(fromOptions.getProxyToolCalls()) - .withToolContext(fromOptions.getToolContext()) - .build(); - } - } diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java index 0ba752c38..1882607e0 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax; +import java.util.ArrayList; +import java.util.List; + import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -39,9 +44,6 @@ import org.springframework.lang.Nullable; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.ArrayList; -import java.util.List; - /** * MiniMax Embedding Model implementation. * diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingOptions.java index d265e2dd6..9dffe18d1 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.embedding.EmbeddingOptions; /** @@ -42,6 +44,21 @@ public class MiniMaxEmbeddingOptions implements EmbeddingOptions { return new Builder(); } + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + @JsonIgnore + public Integer getDimensions() { + return null; + } + public static class Builder { protected MiniMaxEmbeddingOptions options; @@ -61,19 +78,4 @@ public class MiniMaxEmbeddingOptions implements EmbeddingOptions { } - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - @Override - @JsonIgnore - public Integer getDimensions() { - return null; - } - } diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java index 129eb2d76..01d7fb620 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.aot; import org.springframework.ai.minimax.api.MiniMaxApi; diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java index c254d07d6..3216f6940 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.api; import java.util.List; @@ -21,6 +22,13 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.function.Predicate; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonValue; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; @@ -35,14 +43,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonValue; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - // @formatter:off /** * Single class implementation of the MiniMax Chat Completion API and @@ -62,6 +62,8 @@ public class MiniMaxApi { private final WebClient webClient; + private final MiniMaxStreamFunctionCallingHelper chunkMerger = new MiniMaxStreamFunctionCallingHelper(); + /** * Create a new chat completion api with default base URL. * @@ -119,6 +121,99 @@ public class MiniMaxApi { .build(); } + public static String getTextContent(List content) { + return content.stream() + .filter(c -> "text".equals(c.type())) + .map(ChatCompletionMessage.MediaContent::text) + .reduce("", (a, b) -> a + b); + } + + /** + * Creates a model response for the given chat conversation. + * + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post() + .uri("/v1/text/chatcompletion_v2") + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * + * @param chatRequest The chat completion request. Must have the stream property set to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + AtomicBoolean isInsideTool = new AtomicBoolean(false); + + return this.webClient.post() + .uri("/v1/text/chatcompletion_v2") + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + .map(chunk -> { + if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { + isInsideTool.set(true); + } + return chunk; + }) + .windowUntil(chunk -> { + if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + .concatMapIterable(window -> { + Mono monoChunk = window.reduce( + new ChatCompletionChunk(null, null, null, null, null, null), + (previous, current) -> this.chunkMerger.merge(previous, current)); + return List.of(monoChunk); + }) + .flatMap(mono -> mono); + } + + /** + * Creates an embedding vector representing the input text or token array. + * + * @param embeddingRequest The embedding request. + * @return Returns {@link EmbeddingList}. + * + */ + public ResponseEntity embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + + // Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single + // request, pass an array of strings or array of token arrays. + Assert.notNull(embeddingRequest.texts(), "The input can not be null."); + + Assert.isTrue(!CollectionUtils.isEmpty(embeddingRequest.texts()), "The input list can not be empty."); + + return this.restClient.post() + .uri("/v1/embeddings") + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + }); + } + /** * MiniMax Chat Completion Models: * MiniMax Model. @@ -141,7 +236,7 @@ public class MiniMaxApi { } public String getValue() { - return value; + return this.value; } @Override @@ -150,6 +245,85 @@ public class MiniMaxApi { } } + /** + * The reason the model stopped generating tokens. + */ + public enum ChatCompletionFinishReason { + /** + * The model hit a natural stop point or a provided stop sequence. + */ + @JsonProperty("stop") STOP, + /** + * The maximum number of tokens specified in the request was reached. + */ + @JsonProperty("length") LENGTH, + /** + * The content was omitted due to a flag from our content filters. + */ + @JsonProperty("content_filter") CONTENT_FILTER, + /** + * The model called a tool. + */ + @JsonProperty("tool_calls") TOOL_CALLS, + /** + * (deprecated) The model called a function. + */ + @JsonProperty("function_call") FUNCTION_CALL, + /** + * Only for compatibility with Mistral AI API. + */ + @JsonProperty("tool_call") TOOL_CALL + } + + /** + * MiniMax Embeddings Models: + * Embeddings. + */ + public enum EmbeddingModel { + + /** + * DIMENSION: 1536 + */ + Embo_01("embo-01"); + + public final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + } + + /** + * MiniMax Embeddings Types + */ + public enum EmbeddingType { + + /** + * DB, used to generate vectors and store them in the library (as retrieved text) + */ + DB("db"), + + /** + * Query, used to generate vectors for queries (when used as retrieval text) + */ + Query("query"); + + @JsonValue + public final String value; + + EmbeddingType(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + } + /** * Represents a tool the model may call. Currently, only functions are supported as a tool. * @@ -382,6 +556,15 @@ public class MiniMaxApi { @JsonProperty("tool_call_id") String toolCallId, @JsonProperty("tool_calls") List toolCalls) { + /** + * Create a chat completion message with the given content and role. All other fields are null. + * @param content The contents of the message. + * @param role The role of the author of this message. + */ + public ChatCompletionMessage(Object content, Role role) { + this(content, role, null, null, null); + } + /** * Get message content as String. */ @@ -395,15 +578,6 @@ public class MiniMaxApi { throw new IllegalStateException("The content is not a string!"); } - /** - * Create a chat completion message with the given content and role. All other fields are null. - * @param content The contents of the message. - * @param role The role of the author of this message. - */ - public ChatCompletionMessage(Object content, Role role) { - this(content, role, null, null, null); - } - /** * The role of the author of this message. */ @@ -441,22 +615,6 @@ public class MiniMaxApi { @JsonProperty("text") String text, @JsonProperty("image_url") ImageUrl imageUrl) { - /** - * @param url Either a URL of the image or the base64 encoded image data. - * The base64 encoded image data must have a special prefix in the following format: - * "data:{mimetype};base64,{base64-encoded-image-data}". - * @param detail Specifies the detail level of the image. - */ - @JsonInclude(Include.NON_NULL) - public record ImageUrl( - @JsonProperty("url") String url, - @JsonProperty("detail") String detail) { - - public ImageUrl(String url) { - this(url, null); - } - } - /** * Shortcut constructor for a text content. * @param text The text content of the message. @@ -472,6 +630,22 @@ public class MiniMaxApi { public MediaContent(ImageUrl imageUrl) { this("image_url", null, imageUrl); } + + /** + * @param url Either a URL of the image or the base64 encoded image data. + * The base64 encoded image data must have a special prefix in the following format: + * "data:{mimetype};base64,{base64-encoded-image-data}". + * @param detail Specifies the detail level of the image. + */ + @JsonInclude(Include.NON_NULL) + public record ImageUrl( + @JsonProperty("url") String url, + @JsonProperty("detail") String detail) { + + public ImageUrl(String url) { + this(url, null); + } + } } /** * The relevant tool call. @@ -501,43 +675,6 @@ public class MiniMaxApi { } } - public static String getTextContent(List content) { - return content.stream() - .filter(c -> "text".equals(c.type())) - .map(ChatCompletionMessage.MediaContent::text) - .reduce("", (a, b) -> a + b); - } - - /** - * The reason the model stopped generating tokens. - */ - public enum ChatCompletionFinishReason { - /** - * The model hit a natural stop point or a provided stop sequence. - */ - @JsonProperty("stop") STOP, - /** - * The maximum number of tokens specified in the request was reached. - */ - @JsonProperty("length") LENGTH, - /** - * The content was omitted due to a flag from our content filters. - */ - @JsonProperty("content_filter") CONTENT_FILTER, - /** - * The model called a tool. - */ - @JsonProperty("tool_calls") TOOL_CALLS, - /** - * (deprecated) The model called a function. - */ - @JsonProperty("function_call") FUNCTION_CALL, - /** - * Only for compatibility with Mistral AI API. - */ - @JsonProperty("tool_call") TOOL_CALL - } - /** * Represents a chat completion response returned by model, based on the provided input. * @@ -689,118 +826,6 @@ public class MiniMaxApi { } } - /** - * Creates a model response for the given chat conversation. - * - * @param chatRequest The chat completion request. - * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); - - return this.restClient.post() - .uri("/v1/text/chatcompletion_v2") - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletion.class); - } - - private final MiniMaxStreamFunctionCallingHelper chunkMerger = new MiniMaxStreamFunctionCallingHelper(); - - /** - * Creates a streaming chat response for the given chat conversation. - * - * @param chatRequest The chat completion request. Must have the stream property set to true. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); - - AtomicBoolean isInsideTool = new AtomicBoolean(false); - - return this.webClient.post() - .uri("/v1/text/chatcompletion_v2") - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - .takeUntil(SSE_DONE_PREDICATE) - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) - .map(chunk -> { - if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { - isInsideTool.set(true); - } - return chunk; - }) - .windowUntil(chunk -> { - if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - .concatMapIterable(window -> { - Mono monoChunk = window.reduce( - new ChatCompletionChunk(null, null, null, null, null, null), - (previous, current) -> this.chunkMerger.merge(previous, current)); - return List.of(monoChunk); - }) - .flatMap(mono -> mono); - } - - /** - * MiniMax Embeddings Models: - * Embeddings. - */ - public enum EmbeddingModel { - - /** - * DIMENSION: 1536 - */ - Embo_01("embo-01"); - - public final String value; - - EmbeddingModel(String value) { - this.value = value; - } - - public String getValue() { - return value; - } - } - - /** - * MiniMax Embeddings Types - */ - public enum EmbeddingType { - - /** - * DB, used to generate vectors and store them in the library (as retrieved text) - */ - DB("db"), - - /** - * Query, used to generate vectors for queries (when used as retrieval text) - */ - Query("query"); - - @JsonValue - public final String value; - - EmbeddingType(String value) { - this.value = value; - } - - public String getValue() { - return value; - } - } - /** * Creates an embedding vector representing the input text. * @@ -890,30 +915,5 @@ public class MiniMaxApi { @JsonProperty("total_tokens") Integer totalTokens) { } - /** - * Creates an embedding vector representing the input text or token array. - * - * @param embeddingRequest The embedding request. - * @return Returns {@link EmbeddingList}. - * - */ - public ResponseEntity embeddings(EmbeddingRequest embeddingRequest) { - - Assert.notNull(embeddingRequest, "The request body can not be null."); - - // Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single - // request, pass an array of strings or array of token arrays. - Assert.notNull(embeddingRequest.texts(), "The input can not be null."); - - Assert.isTrue(!CollectionUtils.isEmpty(embeddingRequest.texts()), "The input list can not be empty."); - - return this.restClient.post() - .uri("/v1/embeddings") - .body(embeddingRequest) - .retrieve() - .toEntity(new ParameterizedTypeReference<>() { - }); - } - } // @formatter:on diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java index a8ed1b34a..c83d1a448 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.minimax.api; import org.springframework.ai.observation.conventions.AiProvider; diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxStreamFunctionCallingHelper.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxStreamFunctionCallingHelper.java index 82b2eca12..24a71ec0f 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxStreamFunctionCallingHelper.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxStreamFunctionCallingHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.api; +import java.util.ArrayList; +import java.util.List; + import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionChunk; import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionChunk.ChunkChoice; import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionFinishReason; @@ -26,9 +30,6 @@ import org.springframework.ai.minimax.api.MiniMaxApi.LogProbs; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import java.util.ArrayList; -import java.util.List; - /** * Helper class to support Streaming function calling. It can merge the streamed * ChatCompletionChunk in case of function calling message. diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/metadata/MiniMaxUsage.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/metadata/MiniMaxUsage.java index a720be0e4..cb8a5a74a 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/metadata/MiniMaxUsage.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/metadata/MiniMaxUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.metadata; import org.springframework.ai.chat.metadata.Usage; @@ -26,10 +27,6 @@ import org.springframework.util.Assert; */ public class MiniMaxUsage implements Usage { - public static MiniMaxUsage from(MiniMaxApi.Usage usage) { - return new MiniMaxUsage(usage); - } - private final MiniMaxApi.Usage usage; protected MiniMaxUsage(MiniMaxApi.Usage usage) { @@ -37,6 +34,10 @@ public class MiniMaxUsage implements Usage { this.usage = usage; } + public static MiniMaxUsage from(MiniMaxApi.Usage usage) { + return new MiniMaxUsage(usage); + } + protected MiniMaxApi.Usage getUsage() { return this.usage; } diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java index 221a6bccb..5c50ddf30 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax; +import java.util.List; + import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.minimax.api.MiniMaxApi; import org.springframework.ai.minimax.api.MockWeatherService; import org.springframework.ai.model.function.FunctionCallbackWrapper; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/MiniMaxTestConfiguration.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/MiniMaxTestConfiguration.java index 8a7914da9..0493fba62 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/MiniMaxTestConfiguration.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/MiniMaxTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax; import org.springframework.ai.embedding.EmbeddingModel; diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiIT.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiIT.java index 60812302d..52cbd95ae 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiIT.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.api; +import java.util.List; +import java.util.Objects; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletion; import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionChunk; import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage; @@ -24,10 +30,6 @@ import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.Role; import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionRequest; import org.springframework.ai.minimax.api.MiniMaxApi.EmbeddingList; import org.springframework.http.ResponseEntity; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Objects; import static org.assertj.core.api.Assertions.assertThat; @@ -42,7 +44,7 @@ public class MiniMaxApiIT { @Test void chatCompletionEntity() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - ResponseEntity response = miniMaxApi + ResponseEntity response = this.miniMaxApi .chatCompletionEntity(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-4-air", 0.7, false)); assertThat(response).isNotNull(); @@ -52,7 +54,7 @@ public class MiniMaxApiIT { @Test void chatCompletionStream() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - Flux response = miniMaxApi + Flux response = this.miniMaxApi .chatCompletionStream(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-4-air", 0.7, true)); assertThat(response).isNotNull(); @@ -61,7 +63,8 @@ public class MiniMaxApiIT { @Test void embeddings() { - ResponseEntity response = miniMaxApi.embeddings(new MiniMaxApi.EmbeddingRequest("Hello world")); + ResponseEntity response = this.miniMaxApi + .embeddings(new MiniMaxApi.EmbeddingRequest("Hello world")); assertThat(response).isNotNull(); assertThat(Objects.requireNonNull(response.getBody()).vectors()).hasSize(1); diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiToolFunctionCallIT.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiToolFunctionCallIT.java index d878a898d..fbbf90066 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiToolFunctionCallIT.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiToolFunctionCallIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,12 +16,17 @@ package org.springframework.ai.minimax.api; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletion; import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage; import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.Role; @@ -31,10 +36,6 @@ import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionRequest.ToolC import org.springframework.ai.minimax.api.MiniMaxApi.FunctionTool.Type; import org.springframework.http.ResponseEntity; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -49,6 +50,15 @@ public class MiniMaxApiToolFunctionCallIT { MiniMaxApi miniMaxApi = new MiniMaxApi(System.getenv("MINIMAX_API_KEY")); + private static T fromJson(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + @SuppressWarnings("null") @Test public void toolFunctionCall() { @@ -89,7 +99,7 @@ public class MiniMaxApiToolFunctionCallIT { org.springframework.ai.minimax.api.MiniMaxApi.ChatModel.ABAB_6_5_Chat.getValue(), List.of(functionTool), ToolChoiceBuilder.AUTO); - ResponseEntity chatCompletion = miniMaxApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity chatCompletion = this.miniMaxApi.chatCompletionEntity(chatCompletionRequest); assertThat(chatCompletion.getBody()).isNotNull(); assertThat(chatCompletion.getBody().choices()).isNotEmpty(); @@ -108,7 +118,7 @@ public class MiniMaxApiToolFunctionCallIT { MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(), MockWeatherService.Request.class); - MockWeatherService.Response weatherResponse = weatherService.apply(weatherRequest); + MockWeatherService.Response weatherResponse = this.weatherService.apply(weatherRequest); // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL, @@ -119,9 +129,9 @@ public class MiniMaxApiToolFunctionCallIT { var functionResponseRequest = new ChatCompletionRequest(messages, org.springframework.ai.minimax.api.MiniMaxApi.ChatModel.ABAB_6_5_Chat.getValue(), 0.5); - ResponseEntity chatCompletion2 = miniMaxApi.chatCompletionEntity(functionResponseRequest); + ResponseEntity chatCompletion2 = this.miniMaxApi.chatCompletionEntity(functionResponseRequest); - logger.info("Final response: " + chatCompletion2.getBody()); + this.logger.info("Final response: " + chatCompletion2.getBody()); assertThat(Objects.requireNonNull(chatCompletion2.getBody()).choices()).isNotEmpty(); @@ -146,7 +156,7 @@ public class MiniMaxApiToolFunctionCallIT { org.springframework.ai.minimax.api.MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.getValue(), List.of(functionTool), ToolChoiceBuilder.AUTO); - ResponseEntity chatCompletion = miniMaxApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity chatCompletion = this.miniMaxApi.chatCompletionEntity(chatCompletionRequest); assertThat(chatCompletion.getBody()).isNotNull(); assertThat(chatCompletion.getBody().choices()).isNotEmpty(); @@ -158,13 +168,4 @@ public class MiniMaxApiToolFunctionCallIT { assertThat(assistantMessage.content()).contains("40"); } - private static T fromJson(String json, Class targetClass) { - try { - return new ObjectMapper().readValue(json, targetClass); - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - -} \ No newline at end of file +} diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java index b20bada56..d2099ef84 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.api; +import java.util.List; +import java.util.Optional; + import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.minimax.MiniMaxChatModel; @@ -41,10 +47,6 @@ import org.springframework.retry.RetryCallback; import org.springframework.retry.RetryContext; import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Optional; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -58,25 +60,6 @@ import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) public class MiniMaxRetryTests { - private class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private RetryTemplate retryTemplate; @@ -89,13 +72,14 @@ public class MiniMaxRetryTests { @BeforeEach public void beforeEach() { - retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); + this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.registerListener(this.retryListener); - chatModel = new MiniMaxChatModel(miniMaxApi, MiniMaxChatOptions.builder().build(), null, retryTemplate); - embeddingModel = new MiniMaxEmbeddingModel(miniMaxApi, MetadataMode.EMBED, - MiniMaxEmbeddingOptions.builder().build(), retryTemplate); + this.chatModel = new MiniMaxChatModel(this.miniMaxApi, MiniMaxChatOptions.builder().build(), null, + this.retryTemplate); + this.embeddingModel = new MiniMaxEmbeddingModel(this.miniMaxApi, MetadataMode.EMBED, + MiniMaxEmbeddingOptions.builder().build(), this.retryTemplate); } @Test @@ -106,24 +90,24 @@ public class MiniMaxRetryTests { ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666l, "model", null, null, null, new MiniMaxApi.Usage(10, 10, 10)); - when(miniMaxApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.miniMaxApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); - var result = chatModel.call(new Prompt("text")); + var result = this.chatModel.call(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void miniMaxChatNonTransientError() { - when(miniMaxApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.miniMaxApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("text"))); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); } @Test @@ -134,24 +118,24 @@ public class MiniMaxRetryTests { ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666l, "model", null, null); - when(miniMaxApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.miniMaxApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(Flux.just(expectedChatCompletion)); - var result = chatModel.stream(new Prompt("text")); + var result = this.chatModel.stream(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void miniMaxChatStreamNonTransientError() { - when(miniMaxApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.miniMaxApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")).collectList().block()); + assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text")).collectList().block()); } @Test @@ -159,25 +143,45 @@ public class MiniMaxRetryTests { EmbeddingList expectedEmbeddings = new EmbeddingList(List.of(new float[] { 9.9f, 8.8f }), "model", 10); - when(miniMaxApi.embeddings(isA(EmbeddingRequest.class))) + when(this.miniMaxApi.embeddings(isA(EmbeddingRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); - var result = embeddingModel + var result = this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void miniMaxEmbeddingNonTransientError() { - when(miniMaxApi.embeddings(isA(EmbeddingRequest.class))).thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> embeddingModel + when(this.miniMaxApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); } + private class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + + } + } diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MockWeatherService.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MockWeatherService.java index d2f4a9e53..0d3b16452 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MockWeatherService.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,31 +13,37 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.api; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * @author Geng Rong */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, request.unit); } /** @@ -65,28 +71,25 @@ public class MockWeatherService implements Function chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -122,7 +124,7 @@ public class MiniMaxChatModelObservationIT { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java index 23daaf5cb..024de6f26 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java @@ -1,9 +1,32 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.minimax.chat; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -13,12 +36,6 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.minimax.MiniMaxChatModel; import org.springframework.ai.minimax.MiniMaxChatOptions; import org.springframework.ai.minimax.api.MiniMaxApi; -import reactor.core.publisher.Flux; - -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.minimax.api.MiniMaxApi.ChatModel.ABAB_6_5_S_Chat; @@ -42,7 +59,7 @@ public class MiniMaxChatOptionsTests { List messages = new ArrayList<>(List.of(userMessage)); // markSensitiveInfo is enabled by default - ChatResponse response = chatModel.call(new Prompt(messages)); + ChatResponse response = this.chatModel.call(new Prompt(messages)); String responseContent = response.getResult().getOutput().getContent(); assertThat(responseContent).contains("133-**"); @@ -50,7 +67,7 @@ public class MiniMaxChatOptionsTests { var chatOptions = MiniMaxChatOptions.builder().withMaskSensitiveInfo(false).build(); - ChatResponse unmaskResponse = chatModel.call(new Prompt(messages, chatOptions)); + ChatResponse unmaskResponse = this.chatModel.call(new Prompt(messages, chatOptions)); String unmaskResponseContent = unmaskResponse.getResult().getOutput().getContent(); assertThat(unmaskResponseContent).contains("133-12345678"); @@ -80,7 +97,7 @@ public class MiniMaxChatOptionsTests { .withTools(functionTool) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, options)); + ChatResponse response = this.chatModel.call(new Prompt(messages, options)); String responseContent = response.getResult().getOutput().getContent(); assertThat(responseContent).contains("40"); @@ -110,7 +127,7 @@ public class MiniMaxChatOptionsTests { .withTools(functionTool) .build(); - Flux response = chatModel.stream(new Prompt(messages, options)); + Flux response = this.chatModel.stream(new Prompt(messages, options)); String content = Objects.requireNonNull(response.collectList().block()) .stream() .map(ChatResponse::getResults) diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/EmbeddingIT.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/EmbeddingIT.java index 551c9fef6..2ce7e934a 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/EmbeddingIT.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/EmbeddingIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.embedding; +import java.util.List; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.minimax.MiniMaxEmbeddingModel; import org.springframework.ai.minimax.MiniMaxTestConfiguration; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -39,27 +41,27 @@ class EmbeddingIT { @Test void defaultEmbedding() { - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1536); - assertThat(embeddingModel.dimensions()).isEqualTo(1536); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1536); } @Test void batchEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World", "HI")); + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World", "HI")); assertThat(embeddingResponse.getResults()).hasSize(2); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1536); assertThat(embeddingResponse.getResults().get(1)).isNotNull(); assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(1536); - assertThat(embeddingModel.dimensions()).isEqualTo(1536); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1536); } } diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/MiniMaxEmbeddingModelObservationIT.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/MiniMaxEmbeddingModelObservationIT.java index 1c0a8bfb0..51336796c 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/MiniMaxEmbeddingModelObservationIT.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/MiniMaxEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.minimax.embedding; +import java.util.List; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -35,8 +39,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; @@ -62,13 +64,13 @@ public class MiniMaxEmbeddingModelObservationIT { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-mistral-ai/pom.xml b/models/spring-ai-mistral-ai/pom.xml index ea6ba2097..5cae1c540 100644 --- a/models/spring-ai-mistral-ai/pom.xml +++ b/models/spring-ai-mistral-ai/pom.xml @@ -1,33 +1,50 @@ - - 4.0.0 - - org.springframework.ai - spring-ai - 1.0.0-SNAPSHOT - ../../pom.xml - - spring-ai-mistral-ai - jar - Spring AI Model - Mistral AI - Mistral AI models support - https://github.com/spring-projects/spring-ai + - - https://github.com/spring-projects/spring-ai - git://github.com/spring-projects/spring-ai.git - git@github.com:spring-projects/spring-ai.git - + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-mistral-ai + jar + Spring AI Model - Mistral AI + Mistral AI models support + https://github.com/spring-projects/spring-ai - + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + - - - org.springframework.ai - spring-ai-core - ${project.parent.version} - + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + org.springframework.ai @@ -35,24 +52,24 @@ ${project.parent.version} - - - org.springframework - spring-context-support - + + + org.springframework + spring-context-support + - - org.springframework.boot - spring-boot-starter-logging - + + org.springframework.boot + spring-boot-starter-logging + - - - org.springframework.ai - spring-ai-test - ${project.version} - test - + + + org.springframework.ai + spring-ai-test + ${project.version} + test + io.micrometer @@ -60,6 +77,6 @@ test - + diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index bad45cdd4..c699962f9 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import java.util.HashSet; @@ -26,13 +27,20 @@ import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.model.*; +import org.springframework.ai.chat.model.AbstractToolCallSupport; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; @@ -59,9 +67,6 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - /** * Represents a Mistral AI Chat Model. * @@ -74,10 +79,10 @@ import reactor.core.publisher.Mono; */ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatModel { - private final Logger logger = LoggerFactory.getLogger(getClass()); - private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + private final Logger logger = LoggerFactory.getLogger(getClass()); + /** * The default options used for the chat completion requests. */ @@ -140,6 +145,17 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM this.observationRegistry = observationRegistry; } + public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) { + Assert.notNull(result, "Mistral AI ChatCompletion must not be null"); + MistralAiUsage usage = MistralAiUsage.from(result.usage()); + return ChatResponseMetadata.builder() + .withId(result.id()) + .withModel(result.model()) + .withUsage(usage) + .withKeyValue("created", result.created()) + .build(); + } + @Override public ChatResponse call(Prompt prompt) { @@ -156,13 +172,13 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM this.observationRegistry) .observe(() -> { - ResponseEntity completionEntity = retryTemplate + ResponseEntity completionEntity = this.retryTemplate .execute(ctx -> this.mistralAiApi.chatCompletionEntity(request)); ChatCompletion chatCompletion = completionEntity.getBody(); if (chatCompletion == null) { - logger.warn("No chat completion returned for prompt: {}", prompt); + this.logger.warn("No chat completion returned for prompt: {}", prompt); return new ChatResponse(List.of()); } @@ -213,7 +229,7 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); - Flux completionChunks = retryTemplate + Flux completionChunks = this.retryTemplate .execute(ctx -> this.mistralAiApi.chatCompletionStream(request)); // For chunked responses, only the first chunk contains the choice role. @@ -250,7 +266,7 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM } } catch (Exception e) { - logger.error("Error processing chat completion", e); + this.logger.error("Error processing chat completion", e); return new ChatResponse(List.of()); } })); @@ -294,17 +310,6 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM return new Generation(assistantMessage, generationMetadata); } - public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) { - Assert.notNull(result, "Mistral AI ChatCompletion must not be null"); - MistralAiUsage usage = MistralAiUsage.from(result.usage()); - return ChatResponseMetadata.builder() - .withId(result.id()) - .withModel(result.model()) - .withUsage(usage) - .withKeyValue("created", result.created()) - .build(); - } - private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) { List choices = chunk.choices() .stream() diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index f2d818523..d256265bd 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import java.util.ArrayList; @@ -25,10 +26,11 @@ import java.util.Set; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; -import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.mistralai.api.MistralAiApi.FunctionTool; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; @@ -148,6 +150,215 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions return new Builder(); } + public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) { + return builder().withModel(fromOptions.getModel()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withSafePrompt(fromOptions.getSafePrompt()) + .withRandomSeed(fromOptions.getRandomSeed()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withResponseFormat(fromOptions.getResponseFormat()) + .withStop(fromOptions.getStop()) + .withTools(fromOptions.getTools()) + .withToolChoice(fromOptions.getToolChoice()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withToolContext(fromOptions.getToolContext()) + .build(); + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Boolean getSafePrompt() { + return this.safePrompt; + } + + public void setSafePrompt(Boolean safePrompt) { + this.safePrompt = safePrompt; + } + + public Integer getRandomSeed() { + return this.randomSeed; + } + + public void setRandomSeed(Integer randomSeed) { + this.randomSeed = randomSeed; + } + + public ResponseFormat getResponseFormat() { + return this.responseFormat; + } + + public void setResponseFormat(ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + + public List getStop() { + return this.stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + public List getTools() { + return this.tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + public ToolChoice getToolChoice() { + return this.toolChoice; + } + + public void setToolChoice(ToolChoice toolChoice) { + this.toolChoice = toolChoice; + } + + @Override + public Double getTemperature() { + return this.temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + @Override + public Double getTopP() { + return this.topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + @Override + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + @Override + public void setFunctionCallbacks(List functionCallbacks) { + Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); + this.functionCallbacks = functionCallbacks; + } + + @Override + public Set getFunctions() { + return this.functions; + } + + @Override + public void setFunctions(Set functions) { + Assert.notNull(functions, "Function must not be null"); + this.functions = functions; + } + + @Override + @JsonIgnore + public Double getFrequencyPenalty() { + return null; + } + + @Override + @JsonIgnore + public Double getPresencePenalty() { + return null; + } + + @Override + @JsonIgnore + public Integer getTopK() { + return null; + } + + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + + @Override + public Map getToolContext() { + return this.toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + + @Override + public MistralAiChatOptions copy() { + return fromOptions(this); + } + + @Override + public int hashCode() { + + return Objects.hash(this.model, this.temperature, this.topP, this.maxTokens, this.safePrompt, this.randomSeed, + this.responseFormat, this.stop, this.tools, this.toolChoice, this.functionCallbacks, this.functions, + this.proxyToolCalls, this.toolContext); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + MistralAiChatOptions other = (MistralAiChatOptions) obj; + + return Objects.equals(this.model, other.model) && Objects.equals(this.temperature, other.temperature) + && Objects.equals(this.topP, other.topP) && Objects.equals(this.maxTokens, other.maxTokens) + && Objects.equals(this.safePrompt, other.safePrompt) + && Objects.equals(this.randomSeed, other.randomSeed) + && Objects.equals(this.responseFormat, other.responseFormat) && Objects.equals(this.stop, other.stop) + && Objects.equals(this.tools, other.tools) && Objects.equals(this.toolChoice, other.toolChoice) + && Objects.equals(this.functionCallbacks, other.functionCallbacks) + && Objects.equals(this.functions, other.functions) + && Objects.equals(this.proxyToolCalls, other.proxyToolCalls) + && Objects.equals(this.toolContext, other.toolContext); + } + public static class Builder { private final MistralAiChatOptions options = new MistralAiChatOptions(); @@ -245,210 +456,4 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions } - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - @Override - public Integer getMaxTokens() { - return this.maxTokens; - } - - public void setMaxTokens(Integer maxTokens) { - this.maxTokens = maxTokens; - } - - public Boolean getSafePrompt() { - return this.safePrompt; - } - - public void setSafePrompt(Boolean safePrompt) { - this.safePrompt = safePrompt; - } - - public Integer getRandomSeed() { - return this.randomSeed; - } - - public void setRandomSeed(Integer randomSeed) { - this.randomSeed = randomSeed; - } - - public ResponseFormat getResponseFormat() { - return this.responseFormat; - } - - public void setResponseFormat(ResponseFormat responseFormat) { - this.responseFormat = responseFormat; - } - - @Override - @JsonIgnore - public List getStopSequences() { - return getStop(); - } - - @JsonIgnore - public void setStopSequences(List stopSequences) { - setStop(stopSequences); - } - - public List getStop() { - return this.stop; - } - - public void setStop(List stop) { - this.stop = stop; - } - - public void setTools(List tools) { - this.tools = tools; - } - - public List getTools() { - return this.tools; - } - - public void setToolChoice(ToolChoice toolChoice) { - this.toolChoice = toolChoice; - } - - public ToolChoice getToolChoice() { - return this.toolChoice; - } - - @Override - public Double getTemperature() { - return this.temperature; - } - - public void setTemperature(Double temperature) { - this.temperature = temperature; - } - - @Override - public Double getTopP() { - return this.topP; - } - - public void setTopP(Double topP) { - this.topP = topP; - } - - @Override - public List getFunctionCallbacks() { - return this.functionCallbacks; - } - - @Override - public void setFunctionCallbacks(List functionCallbacks) { - Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); - this.functionCallbacks = functionCallbacks; - } - - @Override - public Set getFunctions() { - return this.functions; - } - - @Override - public void setFunctions(Set functions) { - Assert.notNull(functions, "Function must not be null"); - this.functions = functions; - } - - @Override - @JsonIgnore - public Double getFrequencyPenalty() { - return null; - } - - @Override - @JsonIgnore - public Double getPresencePenalty() { - return null; - } - - @Override - @JsonIgnore - public Integer getTopK() { - return null; - } - - @Override - public Boolean getProxyToolCalls() { - return this.proxyToolCalls; - } - - public void setProxyToolCalls(Boolean proxyToolCalls) { - this.proxyToolCalls = proxyToolCalls; - } - - @Override - public Map getToolContext() { - return this.toolContext; - } - - @Override - public void setToolContext(Map toolContext) { - this.toolContext = toolContext; - } - - @Override - public MistralAiChatOptions copy() { - return fromOptions(this); - } - - public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) { - return builder().withModel(fromOptions.getModel()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withSafePrompt(fromOptions.getSafePrompt()) - .withRandomSeed(fromOptions.getRandomSeed()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withResponseFormat(fromOptions.getResponseFormat()) - .withStop(fromOptions.getStop()) - .withTools(fromOptions.getTools()) - .withToolChoice(fromOptions.getToolChoice()) - .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) - .withFunctions(fromOptions.getFunctions()) - .withProxyToolCalls(fromOptions.getProxyToolCalls()) - .withToolContext(fromOptions.getToolContext()) - .build(); - } - - @Override - public int hashCode() { - - return Objects.hash(model, temperature, topP, maxTokens, safePrompt, randomSeed, responseFormat, stop, tools, - toolChoice, functionCallbacks, functions, proxyToolCalls, toolContext); - } - - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - - if (obj == null || getClass() != obj.getClass()) - return false; - - MistralAiChatOptions other = (MistralAiChatOptions) obj; - - return Objects.equals(this.model, other.model) && Objects.equals(this.temperature, other.temperature) - && Objects.equals(this.topP, other.topP) && Objects.equals(this.maxTokens, other.maxTokens) - && Objects.equals(this.safePrompt, other.safePrompt) - && Objects.equals(this.randomSeed, other.randomSeed) - && Objects.equals(this.responseFormat, other.responseFormat) && Objects.equals(this.stop, other.stop) - && Objects.equals(this.tools, other.tools) && Objects.equals(this.toolChoice, other.toolChoice) - && Objects.equals(this.functionCallbacks, other.functionCallbacks) - && Objects.equals(this.functions, other.functions) - && Objects.equals(this.proxyToolCalls, other.proxyToolCalls) - && Objects.equals(this.toolContext, other.toolContext); - } - } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java index 197882130..c3f0a13e4 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import java.util.List; @@ -23,7 +24,13 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; -import org.springframework.ai.embedding.*; +import org.springframework.ai.embedding.AbstractEmbeddingModel; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingOptions.java index 7abfa01fc..6409b05ca 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.embedding.EmbeddingOptions; /** diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java index cd6bcfa40..6ad65d426 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai.aot; import org.springframework.ai.mistralai.api.MistralAiApi; diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java index a6e156f46..41277a0c0 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai.api; import java.util.Arrays; @@ -26,12 +27,12 @@ import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.observation.conventions.AiProvider; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.retry.RetryUtils; import org.springframework.boot.context.properties.bind.ConstructorBinding; import org.springframework.core.ParameterizedTypeReference; @@ -62,15 +63,17 @@ import org.springframework.web.reactive.function.client.WebClient; */ public class MistralAiApi { - private static final String DEFAULT_BASE_URL = "https://api.mistral.ai"; - public static final String PROVIDER_NAME = AiProvider.MISTRAL_AI.value(); + private static final String DEFAULT_BASE_URL = "https://api.mistral.ai"; + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; private final RestClient restClient; - private WebClient webClient; + private final WebClient webClient; + + private final MistralAiStreamFunctionCallingHelper chunkMerger = new MistralAiStreamFunctionCallingHelper(); /** * Create a new client api with DEFAULT_BASE_URL @@ -112,6 +115,201 @@ public class MistralAiApi { this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build(); } + /** + * Creates an embedding vector representing the input text or token array. + * @param embeddingRequest The embedding request. + * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. + * @param Type of the entity in the data list. Can be a {@link String} or + * {@link List} of tokens (e.g. Integers). For embedding multiple inputs in a single + * request, You can pass a {@link List} of {@link String} or {@link List} of + * {@link List} of tokens. For example: + * + *
{@code List.of("text1", "text2", "text3") or List.of(List.of(1, 2, 3), List.of(3, 4, 5))} 
+ */ + public ResponseEntity> embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + + // Input text to embed, encoded as a string or array of tokens. To embed multiple + // inputs in a single + // request, pass an array of strings or array of token arrays. + Assert.notNull(embeddingRequest.input(), "The input can not be null."); + Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, + "The input must be either a String, or a List of Strings or List of List of integers."); + + // The input must not an empty string, and any array must be 1024 dimensions or + // less. + if (embeddingRequest.input() instanceof List list) { + Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); + Assert.isTrue(list.size() <= 1024, "The list must be 1024 dimensions or less"); + Assert.isTrue( + list.get(0) instanceof String || list.get(0) instanceof Integer || list.get(0) instanceof List, + "The input must be either a String, or a List of Strings or list of list of integers."); + } + + return this.restClient.post() + .uri("/v1/embeddings") + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + + }); + } + + /** + * Creates a model response for the given chat conversation. + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code + * and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post() + .uri("/v1/chat/completions") + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + AtomicBoolean isInsideTool = new AtomicBoolean(false); + + return this.webClient.post() + .uri("/v1/chat/completions") + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + .map(chunk -> { + if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { + isInsideTool.set(true); + } + return chunk; + }) + .windowUntil(chunk -> { + if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + .concatMapIterable(window -> { + Mono mono1 = window.reduce(new ChatCompletionChunk(null, null, null, null, null), + (previous, current) -> this.chunkMerger.merge(previous, current)); + return List.of(mono1); + }) + .flatMap(mono -> mono); + } + + /** + * The reason the model stopped generating tokens. + */ + public enum ChatCompletionFinishReason { + + // @formatter:off + /** + * The model hit a natural stop point or a provided stop sequence. + */ + @JsonProperty("stop") STOP, + /** + * The maximum number of tokens specified in the request was reached. + */ + @JsonProperty("length") LENGTH, + /** + * The content was omitted due to a flag from our content filters. + */ + @JsonProperty("model_length") MODEL_LENGTH, + /** + * + */ + @JsonProperty("error") ERROR, + /** + * The model requested a tool call. + */ + @JsonProperty("tool_calls") TOOL_CALLS + // @formatter:on + + } + + /** + * List of well-known Mistral chat models. + * https://docs.mistral.ai/platform/endpoints/#mistral-ai-generative-models + * + *

+ * Mistral AI provides two types of models: open-weights models (Mistral 7B, Mixtral + * 8x7B, Mixtral 8x22B) and optimized commercial models (Mistral Small, Mistral + * Medium, Mistral Large, and Mistral Embeddings). + */ + public enum ChatModel implements ChatModelDescription { + + // @formatter:off + @Deprecated(since = "1.0.0-M1", forRemoval = true) // Replaced by OPEN_MISTRAL_7B + TINY("open-mistral-7b"), + @Deprecated(since = "1.0.0-M1", forRemoval = true) // Replaced by OPEN_MIXTRAL_7B + MIXTRAL("open-mixtral-8x7b"), + OPEN_MISTRAL_7B("open-mistral-7b"), + OPEN_MIXTRAL_7B("open-mixtral-8x7b"), + OPEN_MIXTRAL_22B("open-mixtral-8x22b"), + SMALL("mistral-small-latest"), + @Deprecated(since = "1.0.0-M1", forRemoval = true) // Mistral is removing this model + MEDIUM("mistral-medium-latest"), + LARGE("mistral-large-latest"); + // @formatter:on + + private final String value; + + ChatModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + @Override + public String getName() { + return this.value; + } + + } + + /** + * List of well-known Mistral embedding models. + * https://docs.mistral.ai/platform/endpoints/#mistral-ai-embedding-model + */ + public enum EmbeddingModel { + + // @formatter:off + EMBED("mistral-embed"); + // @formatter:on + + private final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + /** * Represents a tool the model may call. Currently, only functions are supported as a * tool. @@ -168,7 +366,9 @@ public class MistralAiApi { public Function(String description, String name, String jsonSchema) { this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema)); } + } + } /** @@ -218,26 +418,29 @@ public class MistralAiApi { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof Embedding embedding1)) + } + if (!(o instanceof Embedding embedding1)) { return false; - return Objects.equals(index, embedding1.index) && Arrays.equals(embedding, embedding1.embedding) - && Objects.equals(object, embedding1.object); + } + return Objects.equals(this.index, embedding1.index) && Arrays.equals(this.embedding, embedding1.embedding) + && Objects.equals(this.object, embedding1.object); } @Override public int hashCode() { - int result = Objects.hash(index, object); - result = 31 * result + Arrays.hashCode(embedding); + int result = Objects.hash(this.index, this.object); + result = 31 * result + Arrays.hashCode(this.embedding); return result; } @Override public String toString() { - return "Embedding{" + "index=" + index + ", embedding=" + Arrays.toString(embedding) + ", object='" + object - + '\'' + '}'; + return "Embedding{" + "index=" + this.index + ", embedding=" + Arrays.toString(this.embedding) + + ", object='" + this.object + '\'' + '}'; } + } /** @@ -274,6 +477,7 @@ public class MistralAiApi { public EmbeddingRequest(T input) { this(input, EmbeddingModel.EMBED.getValue()); } + } /** @@ -295,46 +499,6 @@ public class MistralAiApi { // @formatter:on } - /** - * Creates an embedding vector representing the input text or token array. - * @param embeddingRequest The embedding request. - * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. - * @param Type of the entity in the data list. Can be a {@link String} or - * {@link List} of tokens (e.g. Integers). For embedding multiple inputs in a single - * request, You can pass a {@link List} of {@link String} or {@link List} of - * {@link List} of tokens. For example: - * - *

{@code List.of("text1", "text2", "text3") or List.of(List.of(1, 2, 3), List.of(3, 4, 5))} 
- */ - public ResponseEntity> embeddings(EmbeddingRequest embeddingRequest) { - - Assert.notNull(embeddingRequest, "The request body can not be null."); - - // Input text to embed, encoded as a string or array of tokens. To embed multiple - // inputs in a single - // request, pass an array of strings or array of token arrays. - Assert.notNull(embeddingRequest.input(), "The input can not be null."); - Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, - "The input must be either a String, or a List of Strings or List of List of integers."); - - // The input must not an empty string, and any array must be 1024 dimensions or - // less. - if (embeddingRequest.input() instanceof List list) { - Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); - Assert.isTrue(list.size() <= 1024, "The list must be 1024 dimensions or less"); - Assert.isTrue( - list.get(0) instanceof String || list.get(0) instanceof Integer || list.get(0) instanceof List, - "The input must be either a String, or a List of Strings or list of list of integers."); - } - - return this.restClient.post() - .uri("/v1/embeddings") - .body(embeddingRequest) - .retrieve() - .toEntity(new ParameterizedTypeReference<>() { - }); - } - /** * Creates a model request for chat conversation. * @@ -472,7 +636,9 @@ public class MistralAiApi { */ @JsonInclude(Include.NON_NULL) public record ResponseFormat(@JsonProperty("type") String type) { + } + } /** @@ -547,6 +713,7 @@ public class MistralAiApi { @JsonInclude(Include.NON_NULL) public record ToolCall(@JsonProperty("id") String id, @JsonProperty("type") String type, @JsonProperty("function") ChatCompletionFunction function) { + } /** @@ -559,36 +726,8 @@ public class MistralAiApi { @JsonInclude(Include.NON_NULL) public record ChatCompletionFunction(@JsonProperty("name") String name, @JsonProperty("arguments") String arguments) { + } - } - - /** - * The reason the model stopped generating tokens. - */ - public enum ChatCompletionFinishReason { - - // @formatter:off - /** - * The model hit a natural stop point or a provided stop sequence. - */ - @JsonProperty("stop") STOP, - /** - * The maximum number of tokens specified in the request was reached. - */ - @JsonProperty("length") LENGTH, - /** - * The content was omitted due to a flag from our content filters. - */ - @JsonProperty("model_length") MODEL_LENGTH, - /** - * - */ - @JsonProperty("error") ERROR, - /** - * The model requested a tool call. - */ - @JsonProperty("tool_calls") TOOL_CALLS - // @formatter:on } @@ -632,6 +771,7 @@ public class MistralAiApi { @JsonProperty("logprobs") LogProbs logprobs) { // @formatter:on } + } /** @@ -676,8 +816,11 @@ public class MistralAiApi { @JsonInclude(Include.NON_NULL) public record TopLogProbs(@JsonProperty("token") String token, @JsonProperty("logprob") Float logprob, @JsonProperty("bytes") List probBytes) { + } + } + } /** @@ -719,132 +862,7 @@ public class MistralAiApi { @JsonProperty("logprobs") LogProbs logprobs) { // @formatter:on } - } - /** - * List of well-known Mistral chat models. - * https://docs.mistral.ai/platform/endpoints/#mistral-ai-generative-models - * - *

- * Mistral AI provides two types of models: open-weights models (Mistral 7B, Mixtral - * 8x7B, Mixtral 8x22B) and optimized commercial models (Mistral Small, Mistral - * Medium, Mistral Large, and Mistral Embeddings). - */ - public enum ChatModel implements ChatModelDescription { - - // @formatter:off - @Deprecated(since = "1.0.0-M1", forRemoval = true) // Replaced by OPEN_MISTRAL_7B - TINY("open-mistral-7b"), - @Deprecated(since = "1.0.0-M1", forRemoval = true) // Replaced by OPEN_MIXTRAL_7B - MIXTRAL("open-mixtral-8x7b"), - OPEN_MISTRAL_7B("open-mistral-7b"), - OPEN_MIXTRAL_7B("open-mixtral-8x7b"), - OPEN_MIXTRAL_22B("open-mixtral-8x22b"), - SMALL("mistral-small-latest"), - @Deprecated(since = "1.0.0-M1", forRemoval = true) // Mistral is removing this model - MEDIUM("mistral-medium-latest"), - LARGE("mistral-large-latest"); - // @formatter:on - - private final String value; - - ChatModel(String value) { - this.value = value; - } - - public String getValue() { - return this.value; - } - - @Override - public String getName() { - return this.value; - } - - } - - /** - * List of well-known Mistral embedding models. - * https://docs.mistral.ai/platform/endpoints/#mistral-ai-embedding-model - */ - public enum EmbeddingModel { - - // @formatter:off - EMBED("mistral-embed"); - // @formatter:on - - private final String value; - - EmbeddingModel(String value) { - this.value = value; - } - - public String getValue() { - return this.value; - } - - } - - /** - * Creates a model response for the given chat conversation. - * @param chatRequest The chat completion request. - * @return Entity response with {@link ChatCompletion} as a body and HTTP status code - * and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); - - return this.restClient.post() - .uri("/v1/chat/completions") - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletion.class); - } - - private MistralAiStreamFunctionCallingHelper chunkMerger = new MistralAiStreamFunctionCallingHelper(); - - /** - * Creates a streaming chat response for the given chat conversation. - * @param chatRequest The chat completion request. Must have the stream property set - * to true. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); - - AtomicBoolean isInsideTool = new AtomicBoolean(false); - - return this.webClient.post() - .uri("/v1/chat/completions") - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - .takeUntil(SSE_DONE_PREDICATE) - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) - .map(chunk -> { - if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { - isInsideTool.set(true); - } - return chunk; - }) - .windowUntil(chunk -> { - if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - .concatMapIterable(window -> { - Mono mono1 = window.reduce(new ChatCompletionChunk(null, null, null, null, null), - (previous, current) -> this.chunkMerger.merge(previous, current)); - return List.of(mono1); - }) - .flatMap(mono -> mono); } } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java index 774bd0729..c00249eea 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai.api; import java.util.ArrayList; @@ -105,7 +106,7 @@ public class MistralAiStreamFunctionCallingHelper { private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) { String content = (current.content() != null ? current.content() - : "" + ((previous.content() != null) ? previous.content() : "")); + : (previous.content() != null) ? previous.content() : ""); Role role = (current.role() != null ? current.role() : previous.role()); role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null String name = (current.name() != null ? current.name() : previous.name()); @@ -198,4 +199,4 @@ public class MistralAiStreamFunctionCallingHelper { } } -// --- \ No newline at end of file +// --- diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/metadata/MistralAiUsage.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/metadata/MistralAiUsage.java index c89982349..dbcc9a9d4 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/metadata/MistralAiUsage.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/metadata/MistralAiUsage.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.mistralai.metadata; import org.springframework.ai.chat.metadata.Usage; @@ -13,10 +29,6 @@ import org.springframework.util.Assert; */ public class MistralAiUsage implements Usage { - public static MistralAiUsage from(MistralAiApi.Usage usage) { - return new MistralAiUsage(usage); - } - private final MistralAiApi.Usage usage; protected MistralAiUsage(MistralAiApi.Usage usage) { @@ -24,6 +36,10 @@ public class MistralAiUsage implements Usage { this.usage = usage; } + public static MistralAiUsage from(MistralAiApi.Usage usage) { + return new MistralAiUsage(usage); + } + protected MistralAiApi.Usage getUsage() { return this.usage; } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java index 618ca25a6..26b329137 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import java.util.Arrays; @@ -27,11 +28,8 @@ import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.mistralai.api.MistralAiApi; @@ -57,14 +55,11 @@ class MistralAiChatClientIT { @Value("classpath:/prompts/system-message.st") private Resource systemTextResource; - record ActorsFilms(String actor, List movies) { - } - @Test void call() { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() - .system(s -> s.text(systemTextResource) + ChatResponse response = ChatClient.create(this.chatModel).prompt() + .system(s -> s.text(this.systemTextResource) .param("name", "Bob") .param("voice", "pirate")) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") @@ -81,8 +76,8 @@ class MistralAiChatClientIT { void testMessageHistory() { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() - .system(s -> s.text(systemTextResource) + ChatResponse response = ChatClient.create(this.chatModel).prompt() + .system(s -> s.text(this.systemTextResource) .param("name", "Bob") .param("voice", "pirate")) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") @@ -92,7 +87,7 @@ class MistralAiChatClientIT { assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard"); // @formatter:off - response = ChatClient.create(chatModel).prompt() + response = ChatClient.create(this.chatModel).prompt() .messages(List.of(new UserMessage("Dummy"), response.getResult().getOutput())) .user("Repeat the last assistant message.") .call() @@ -107,7 +102,7 @@ class MistralAiChatClientIT { @Test void listOutputConverterString() { // @formatter:off - List collection = ChatClient.create(chatModel).prompt() + List collection = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("List five {subject}") .param("subject", "ice cream flavors")) .call() @@ -122,7 +117,7 @@ class MistralAiChatClientIT { void listOutputConverterBean() { // @formatter:off - List actorsFilms = ChatClient.create(chatModel).prompt() + List actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography of 5 movies for Tom Hanks and Bill Murray.") .call() .entity(new ParameterizedTypeReference>() { @@ -139,7 +134,7 @@ class MistralAiChatClientIT { var toStringListConverter = new ListOutputConverter(new DefaultConversionService()); // @formatter:off - List flavors = ChatClient.create(chatModel).prompt() + List flavors = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("List 10 {subject}") .param("subject", "ice cream flavors")) .call() @@ -154,7 +149,7 @@ class MistralAiChatClientIT { @Test void mapOutputConverter() { // @formatter:off - Map result = ChatClient.create(chatModel).prompt() + Map result = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("Provide me a List of {subject}") .param("subject", "an array of numbers from 1 to 9 under they key name 'numbers'")) .call() @@ -169,7 +164,7 @@ class MistralAiChatClientIT { void beanOutputConverter() { // @formatter:off - ActorsFilms actorsFilms = ChatClient.create(chatModel).prompt() + ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography for a random actor.") .call() .entity(ActorsFilms.class); @@ -183,7 +178,7 @@ class MistralAiChatClientIT { void beanOutputConverterRecords() { // @formatter:off - ActorsFilms actorsFilms = ChatClient.create(chatModel).prompt() + ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography of 5 movies for Tom Hanks.") .call() .entity(ActorsFilms.class); @@ -200,7 +195,7 @@ class MistralAiChatClientIT { BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilms.class); // @formatter:off - Flux chatResponse = ChatClient.create(chatModel) + Flux chatResponse = ChatClient.create(this.chatModel) .prompt() .user(u -> u .text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator() @@ -226,7 +221,7 @@ class MistralAiChatClientIT { void functionCallTest() { // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() .options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).withToolChoice(ToolChoice.AUTO).build()) .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")) .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) @@ -245,7 +240,7 @@ class MistralAiChatClientIT { void defaultFunctionCallTest() { // @formatter:off - String response = ChatClient.builder(chatModel) + String response = ChatClient.builder(this.chatModel) .defaultOptions(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build()) .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")) @@ -264,7 +259,7 @@ class MistralAiChatClientIT { void streamFunctionCallTest() { // @formatter:off - Flux response = ChatClient.create(chatModel).prompt() + Flux response = ChatClient.create(this.chatModel).prompt() .options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build()) .user("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.") .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) @@ -284,7 +279,7 @@ class MistralAiChatClientIT { void validateCallResponseMetadata() { String model = MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getName(); // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(MistralAiChatOptions.builder().withModel(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() @@ -299,4 +294,8 @@ class MistralAiChatClientIT { assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } + record ActorsFilms(String actor, List movies) { + + } + } \ No newline at end of file diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java index 29ffdb75a..d4efe0660 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import org.junit.jupiter.api.Test; @@ -37,7 +38,7 @@ public class MistralAiChatCompletionRequestTest { @Test void chatCompletionDefaultRequestTest() { - var request = chatModel.createRequest(new Prompt("test content"), false); + var request = this.chatModel.createRequest(new Prompt("test content"), false); assertThat(request.messages()).hasSize(1); assertThat(request.topP()).isEqualTo(1); @@ -52,7 +53,7 @@ public class MistralAiChatCompletionRequestTest { var options = MistralAiChatOptions.builder().withTemperature(0.5).withTopP(0.8).build(); - var request = chatModel.createRequest(new Prompt("test content", options), true); + var request = this.chatModel.createRequest(new Prompt("test content", options), true); assertThat(request.messages().size()).isEqualTo(1); assertThat(request.topP()).isEqualTo(0.8); diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java index 0ec22331d..9e3b28230 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import java.util.ArrayList; @@ -27,13 +28,13 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -66,9 +67,6 @@ class MistralAiChatModelIT { @Autowired protected StreamingChatModel streamingChatModel; - @Value("classpath:/prompts/system-message.st") - private Resource systemResource; - @Value("classpath:/prompts/eval/qa-evaluator-accurate-answer.st") protected Resource qaEvaluatorAccurateAnswerResource; @@ -81,16 +79,19 @@ class MistralAiChatModelIT { @Value("classpath:/prompts/eval/user-evaluator-message.st") protected Resource userEvaluatorResource; + @Value("classpath:/prompts/system-message.st") + private Resource systemResource; + @Test void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); // NOTE: Mistral expects the system message to be before the user message or will // fail with 400 error. Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); } @@ -126,16 +127,13 @@ class MistralAiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -148,7 +146,7 @@ class MistralAiChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -169,7 +167,7 @@ class MistralAiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = streamingChatModel.stream(prompt) + String generationTextFromStream = this.streamingChatModel.stream(prompt) .collectList() .block() .stream() @@ -202,7 +200,7 @@ class MistralAiChatModelIT { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -225,7 +223,7 @@ class MistralAiChatModelIT { .build())) .build(); - Flux response = streamingChatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.streamingChatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -240,4 +238,8 @@ class MistralAiChatModelIT { assertThat(content).containsAnyOf("10.0", "10"); } + record ActorsFilmsRecord(String actor, List movies) { + + } + } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java index 4631ab807..a6a42311d 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; +import java.util.List; +import java.util.stream.Collectors; + import io.micrometer.common.KeyValue; import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; @@ -35,10 +41,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -61,7 +63,7 @@ public class MistralAiChatModelObservationIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -76,7 +78,7 @@ public class MistralAiChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -97,7 +99,7 @@ public class MistralAiChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -118,7 +120,7 @@ public class MistralAiChatModelObservationIT { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingIT.java index de378c57a..b9c91cca8 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import java.util.List; @@ -35,21 +36,21 @@ class MistralAiEmbeddingIT { @Test void defaultEmbedding() { - assertThat(mistralAiEmbeddingModel).isNotNull(); - var embeddingResponse = mistralAiEmbeddingModel.embedForResponse(List.of("Hello World")); + assertThat(this.mistralAiEmbeddingModel).isNotNull(); + var embeddingResponse = this.mistralAiEmbeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024); assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed"); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4); assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4); - assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024); } @Test void embeddingTest() { - assertThat(mistralAiEmbeddingModel).isNotNull(); - var embeddingResponse = mistralAiEmbeddingModel.call(new EmbeddingRequest( + assertThat(this.mistralAiEmbeddingModel).isNotNull(); + var embeddingResponse = this.mistralAiEmbeddingModel.call(new EmbeddingRequest( List.of("Hello World", "World is big"), MistralAiEmbeddingOptions.builder().withModel("mistral-embed").withEncodingFormat("float").build())); assertThat(embeddingResponse.getResults()).hasSize(2); @@ -58,7 +59,7 @@ class MistralAiEmbeddingIT { assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed"); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(9); assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(9); - assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024); } } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelObservationIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelObservationIT.java index 55f634c0c..813c62bbc 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelObservationIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; +import java.util.List; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -33,8 +37,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; @@ -63,13 +65,13 @@ public class MistralAiEmbeddingModelObservationIT { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java index 28b29fcd3..1818d2e43 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.mistralai; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.when; +package org.springframework.ai.mistralai; import java.util.List; import java.util.Optional; @@ -29,6 +25,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.mistralai.api.MistralAiApi; @@ -49,7 +47,10 @@ import org.springframework.retry.RetryContext; import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.when; /** * @author Christian Tzolov @@ -59,25 +60,6 @@ import reactor.core.publisher.Flux; @ExtendWith(MockitoExtension.class) public class MistralAiRetryTests { - private static class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private RetryTemplate retryTemplate; @@ -90,21 +72,21 @@ public class MistralAiRetryTests { @BeforeEach public void beforeEach() { - retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); + this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.registerListener(this.retryListener); - chatModel = new MistralAiChatModel(mistralAiApi, + this.chatModel = new MistralAiChatModel(this.mistralAiApi, MistralAiChatOptions.builder() .withTemperature(0.7) .withTopP(1.0) .withSafePrompt(false) .withModel(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue()) .build(), - null, retryTemplate); - embeddingModel = new MistralAiEmbeddingModel(mistralAiApi, MetadataMode.EMBED, + null, this.retryTemplate); + this.embeddingModel = new MistralAiEmbeddingModel(this.mistralAiApi, MetadataMode.EMBED, MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(), - retryTemplate); + this.retryTemplate); } @Test @@ -112,27 +94,27 @@ public class MistralAiRetryTests { var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT), ChatCompletionFinishReason.STOP, null); - ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789l, "model", + ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789L, "model", List.of(choice), new MistralAiApi.Usage(10, 10, 10)); - when(mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); - var result = chatModel.call(new Prompt("text")); + var result = this.chatModel.call(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void mistralAiChatNonTransientError() { - when(mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("text"))); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); } @Test @@ -141,28 +123,28 @@ public class MistralAiRetryTests { var choice = new ChatCompletionChunk.ChunkChoice(0, new ChatCompletionMessage("Response", Role.ASSISTANT), ChatCompletionFinishReason.STOP, null); - ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789l, + ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789L, "model", List.of(choice)); - when(mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(Flux.just(expectedChatCompletion)); - var result = chatModel.stream(new Prompt("text")); + var result = this.chatModel.stream(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test @Disabled("Currently stream() does not implement retry") public void mistralAiChatStreamNonTransientError() { - when(mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text"))); + assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text"))); } @Test @@ -171,26 +153,45 @@ public class MistralAiRetryTests { EmbeddingList expectedEmbeddings = new EmbeddingList<>("list", List.of(new Embedding(0, new float[] { 9.9f, 8.8f })), "model", new MistralAiApi.Usage(10, 10, 10)); - when(mistralAiApi.embeddings(isA(EmbeddingRequest.class))) + when(this.mistralAiApi.embeddings(isA(EmbeddingRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); - var result = embeddingModel + var result = this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void mistralAiEmbeddingNonTransientError() { - when(mistralAiApi.embeddings(isA(EmbeddingRequest.class))) + when(this.mistralAiApi.embeddings(isA(EmbeddingRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> embeddingModel + assertThrows(RuntimeException.class, () -> this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); } + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + + } + } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java index b48373f9f..608eccca5 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import org.springframework.ai.embedding.EmbeddingModel; diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MockWeatherService.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MockWeatherService.java index 4a6f594a9..0c7c4dacf 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MockWeatherService.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai; import java.util.function.Function; @@ -28,14 +29,21 @@ import com.fasterxml.jackson.annotation.JsonPropertyDescription; */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -57,34 +65,29 @@ public class MockWeatherService implements Function response = mistralAiApi.chatCompletionEntity(new ChatCompletionRequest( + ResponseEntity response = this.mistralAiApi.chatCompletionEntity(new ChatCompletionRequest( List.of(chatCompletionMessage), MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(), 0.8, false)); assertThat(response).isNotNull(); @@ -62,7 +63,7 @@ public class MistralAiApiIT { You should reply to the user's request with your name and also in the style of a pirate. """, Role.SYSTEM); - ResponseEntity response = mistralAiApi.chatCompletionEntity(new ChatCompletionRequest( + ResponseEntity response = this.mistralAiApi.chatCompletionEntity(new ChatCompletionRequest( List.of(systemMessage, userMessage), MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(), 0.8, false)); assertThat(response).isNotNull(); @@ -72,7 +73,7 @@ public class MistralAiApiIT { @Test void chatCompletionStream() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - Flux response = mistralAiApi.chatCompletionStream(new ChatCompletionRequest( + Flux response = this.mistralAiApi.chatCompletionStream(new ChatCompletionRequest( List.of(chatCompletionMessage), MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(), 0.8, true)); assertThat(response).isNotNull(); @@ -81,7 +82,7 @@ public class MistralAiApiIT { @Test void embeddings() { - ResponseEntity> response = mistralAiApi + ResponseEntity> response = this.mistralAiApi .embeddings(new MistralAiApi.EmbeddingRequest("Hello world")); assertThat(response).isNotNull(); diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MistralAiApiToolFunctionCallIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MistralAiApiToolFunctionCallIT.java index 4d23255a6..b4ea08ed0 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MistralAiApiToolFunctionCallIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MistralAiApiToolFunctionCallIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai.api.tool; import java.util.ArrayList; @@ -31,8 +32,8 @@ import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletion; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.Role; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall; -import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest; +import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; import org.springframework.ai.mistralai.api.MistralAiApi.FunctionTool.Type; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.http.ResponseEntity; @@ -46,14 +47,23 @@ import static org.assertj.core.api.Assertions.assertThat; @Disabled public class MistralAiApiToolFunctionCallIT { + static final String MISTRAL_AI_CHAT_MODEL = MistralAiApi.ChatModel.LARGE.getValue(); + private final Logger logger = LoggerFactory.getLogger(MistralAiApiToolFunctionCallIT.class); MockWeatherService weatherService = new MockWeatherService(); - static final String MISTRAL_AI_CHAT_MODEL = MistralAiApi.ChatModel.LARGE.getValue(); - MistralAiApi completionApi = new MistralAiApi(System.getenv("MISTRAL_AI_API_KEY")); + private static T fromJson(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + @Test @SuppressWarnings("null") public void toolFunctionCall() throws JsonProcessingException { @@ -100,7 +110,7 @@ public class MistralAiApiToolFunctionCallIT { System.out .println(new ObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(chatCompletionRequest)); - ResponseEntity chatCompletion = completionApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity chatCompletion = this.completionApi.chatCompletionEntity(chatCompletionRequest); assertThat(chatCompletion.getBody()).isNotNull(); assertThat(chatCompletion.getBody().choices()).isNotEmpty(); @@ -123,7 +133,7 @@ public class MistralAiApiToolFunctionCallIT { MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(), MockWeatherService.Request.class); - MockWeatherService.Response weatherResponse = weatherService.apply(weatherRequest); + MockWeatherService.Response weatherResponse = this.weatherService.apply(weatherRequest); // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), @@ -133,10 +143,10 @@ public class MistralAiApiToolFunctionCallIT { var functionResponseRequest = new ChatCompletionRequest(messages, MISTRAL_AI_CHAT_MODEL, 0.8); - ResponseEntity chatCompletion2 = completionApi + ResponseEntity chatCompletion2 = this.completionApi .chatCompletionEntity(functionResponseRequest); - logger.info("Final response: " + chatCompletion2.getBody()); + this.logger.info("Final response: " + chatCompletion2.getBody()); assertThat(chatCompletion2.getBody().choices()).isNotEmpty(); @@ -145,21 +155,10 @@ public class MistralAiApiToolFunctionCallIT { .containsAnyOf("30.0°C", "30°C"); assertThat(chatCompletion2.getBody().choices().get(0).message().content()).contains("Tokyo") .containsAnyOf("10.0°C", "10°C"); - ; assertThat(chatCompletion2.getBody().choices().get(0).message().content()).contains("Paris") .containsAnyOf("15.0°C", "15°C"); - ; } } - private static T fromJson(String json, Class targetClass) { - try { - return new ObjectMapper().readValue(json, targetClass); - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - -} \ No newline at end of file +} diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MockWeatherService.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MockWeatherService.java index 1c7c0d4de..c468dffbc 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MockWeatherService.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.mistralai.api.tool; import java.util.function.Function; @@ -28,14 +29,21 @@ import com.fasterxml.jackson.annotation.JsonPropertyDescription; */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -57,34 +65,29 @@ public class MockWeatherService implements Function DATA = Map.of("T1001", new StatusDate("Paid", "2021-10-05"), "T1002", new StatusDate("Unpaid", "2021-10-06"), "T1003", new StatusDate("Paid", "2021-10-07"), "T1004", new StatusDate("Paid", "2021-10-05"), "T1005", new StatusDate("Pending", "2021-10-08")); - record StatusDate(String status, String date) { - } - - public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) { - } - - public record Status(@JsonProperty(required = true, value = "status") String status) { - } - - public record Date(@JsonProperty(required = true, value = "date") String date) { - } - - private static class RetrievePaymentStatus implements Function { - - @Override - public Status apply(Transaction paymentTransaction) { - return new Status(DATA.get(paymentTransaction.transactionId).status); - } - - } - - private static class RetrievePaymentDate implements Function { - - @Override - public Date apply(Transaction paymentTransaction) { - return new Date(DATA.get(paymentTransaction.transactionId).date); - } - - } - static Map> functions = Map.of("retrieve_payment_status", new RetrievePaymentStatus(), "retrieve_payment_date", new RetrievePaymentDate()); + private final Logger logger = LoggerFactory.getLogger(PaymentStatusFunctionCallingIT.class); + + private static T jsonToObject(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + @Test @SuppressWarnings("null") public void toolFunctionCall() throws JsonProcessingException { @@ -157,19 +136,44 @@ public class PaymentStatusFunctionCallingIT { .chatCompletionEntity(new ChatCompletionRequest(messages, MistralAiApi.ChatModel.LARGE.getValue())); var responseContent = response.getBody().choices().get(0).message().content(); - logger.info("Final response: " + responseContent); + this.logger.info("Final response: " + responseContent); assertThat(responseContent).containsIgnoringCase("T1001"); assertThat(responseContent).containsIgnoringCase("Paid"); } - private static T jsonToObject(String json, Class targetClass) { - try { - return new ObjectMapper().readValue(json, targetClass); + record StatusDate(String status, String date) { + + } + + public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) { + + } + + public record Status(@JsonProperty(required = true, value = "status") String status) { + + } + + public record Date(@JsonProperty(required = true, value = "date") String date) { + + } + + private static class RetrievePaymentStatus implements Function { + + @Override + public Status apply(Transaction paymentTransaction) { + return new Status(DATA.get(paymentTransaction.transactionId).status); } - catch (JsonProcessingException e) { - throw new RuntimeException(e); + + } + + private static class RetrievePaymentDate implements Function { + + @Override + public Date apply(Transaction paymentTransaction) { + return new Date(DATA.get(paymentTransaction.transactionId).date); } + } } diff --git a/models/spring-ai-moonshot/pom.xml b/models/spring-ai-moonshot/pom.xml index 6f7f92b52..84a154850 100644 --- a/models/spring-ai-moonshot/pom.xml +++ b/models/spring-ai-moonshot/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java index bc5c3d13d..2751a9d64 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -60,14 +70,6 @@ import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; /** * @author Geng Rong @@ -158,6 +160,21 @@ public class MoonshotChatModel extends AbstractToolCallSupport implements ChatMo this.observationRegistry = observationRegistry; } + private static Generation buildGeneration(Choice choice, Map metadata) { + List toolCalls = choice.message().toolCalls() == null ? List.of() + : choice.message() + .toolCalls() + .stream() + .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function", + toolCall.function().name(), toolCall.function().arguments())) + .toList(); + + var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); + String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); + var generationMetadata = ChatGenerationMetadata.from(finishReason, null); + return new Generation(assistantMessage, generationMetadata); + } + @Override public ChatResponse call(Prompt prompt) { ChatCompletionRequest request = createRequest(prompt, false); @@ -305,21 +322,6 @@ public class MoonshotChatModel extends AbstractToolCallSupport implements ChatMo .build(); } - private static Generation buildGeneration(Choice choice, Map metadata) { - List toolCalls = choice.message().toolCalls() == null ? List.of() - : choice.message() - .toolCalls() - .stream() - .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function", - toolCall.function().name(), toolCall.function().arguments())) - .toList(); - - var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); - String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); - var generationMetadata = ChatGenerationMetadata.from(finishReason, null); - return new Generation(assistantMessage, generationMetadata); - } - /** * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. * @param chunk the ChatCompletionChunk to convert diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java index b5dd81097..e5bae8560 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.moonshot; -import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.moonshot.api.MoonshotApi; -import org.springframework.boot.context.properties.NestedConfigurationProperty; -import org.springframework.util.Assert; +package org.springframework.ai.moonshot; import java.util.ArrayList; import java.util.HashSet; @@ -31,6 +22,17 @@ import java.util.List; import java.util.Map; import java.util.Set; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.moonshot.api.MoonshotApi; +import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; + /** * @author Geng Rong * @author Thomas Vitale @@ -145,6 +147,10 @@ public class MoonshotChatOptions implements FunctionCallingOptions, ChatOptions @JsonIgnore private Map toolContext; + public static Builder builder() { + return new Builder(); + } + @Override public List getFunctionCallbacks() { return this.functionCallbacks; @@ -157,122 +163,13 @@ public class MoonshotChatOptions implements FunctionCallingOptions, ChatOptions @Override public Set getFunctions() { - return functions; + return this.functions; } public void setFunctions(Set functionNames) { this.functions = functionNames; } - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - - protected MoonshotChatOptions options; - - public Builder() { - this.options = new MoonshotChatOptions(); - } - - public Builder(MoonshotChatOptions options) { - this.options = options; - } - - public Builder withModel(String model) { - this.options.model = model; - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.maxTokens = maxTokens; - return this; - } - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public Builder withN(Integer n) { - this.options.n = n; - return this; - } - - public Builder withPresencePenalty(Double presencePenalty) { - this.options.presencePenalty = presencePenalty; - return this; - } - - public Builder withFrequencyPenalty(Double frequencyPenalty) { - this.options.frequencyPenalty = frequencyPenalty; - return this; - } - - public Builder withStop(List stop) { - this.options.stop = stop; - return this; - } - - public Builder withUser(String user) { - this.options.user = user; - return this; - } - - public Builder withTools(List tools) { - this.options.tools = tools; - return this; - } - - public Builder withToolChoice(String toolChoice) { - this.options.toolChoice = toolChoice; - return this; - } - - public Builder withFunctionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; - return this; - } - - public Builder withFunctions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; - return this; - } - - public Builder withFunction(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); - return this; - } - - public Builder withProxyToolCalls(Boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; - return this; - } - - public Builder withToolContext(Map toolContext) { - if (this.options.toolContext == null) { - this.options.toolContext = toolContext; - } - else { - this.options.toolContext.putAll(toolContext); - } - return this; - } - - public MoonshotChatOptions build() { - return this.options; - } - - } - @Override public String getModel() { return this.model; @@ -411,93 +308,220 @@ public class MoonshotChatOptions implements FunctionCallingOptions, ChatOptions public int hashCode() { final int prime = 31; int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode()); - result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); - result = prime * result + ((n == null) ? 0 : n.hashCode()); - result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode()); - result = prime * result + ((stop == null) ? 0 : stop.hashCode()); - result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); - result = prime * result + ((topP == null) ? 0 : topP.hashCode()); - result = prime * result + ((user == null) ? 0 : user.hashCode()); - result = prime * result + ((proxyToolCalls == null) ? 0 : proxyToolCalls.hashCode()); - result = prime * result + ((toolContext == null) ? 0 : toolContext.hashCode()); + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); + result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); + result = prime * result + ((this.n == null) ? 0 : this.n.hashCode()); + result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); + result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); + result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); + result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); + result = prime * result + ((this.user == null) ? 0 : this.user.hashCode()); + result = prime * result + ((this.proxyToolCalls == null) ? 0 : this.proxyToolCalls.hashCode()); + result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); return result; } @Override public boolean equals(Object obj) { - if (this == obj) + if (this == obj) { return true; - if (obj == null) + } + if (obj == null) { return false; - if (getClass() != obj.getClass()) + } + if (getClass() != obj.getClass()) { return false; + } MoonshotChatOptions other = (MoonshotChatOptions) obj; if (this.model == null) { - if (other.model != null) + if (other.model != null) { return false; + } } - else if (!model.equals(other.model)) + else if (!this.model.equals(other.model)) { return false; + } if (this.frequencyPenalty == null) { - if (other.frequencyPenalty != null) + if (other.frequencyPenalty != null) { return false; + } } - else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) + else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) { return false; + } if (this.maxTokens == null) { - if (other.maxTokens != null) + if (other.maxTokens != null) { return false; + } } - else if (!this.maxTokens.equals(other.maxTokens)) + else if (!this.maxTokens.equals(other.maxTokens)) { return false; + } if (this.n == null) { - if (other.n != null) + if (other.n != null) { return false; + } } - else if (!this.n.equals(other.n)) + else if (!this.n.equals(other.n)) { return false; + } if (this.presencePenalty == null) { - if (other.presencePenalty != null) + if (other.presencePenalty != null) { return false; + } } - else if (!this.presencePenalty.equals(other.presencePenalty)) + else if (!this.presencePenalty.equals(other.presencePenalty)) { return false; + } if (this.stop == null) { - if (other.stop != null) + if (other.stop != null) { return false; + } } - else if (!stop.equals(other.stop)) + else if (!this.stop.equals(other.stop)) { return false; + } if (this.temperature == null) { - if (other.temperature != null) + if (other.temperature != null) { return false; + } } - else if (!this.temperature.equals(other.temperature)) + else if (!this.temperature.equals(other.temperature)) { return false; + } if (this.topP == null) { - if (other.topP != null) + if (other.topP != null) { return false; + } } - else if (!topP.equals(other.topP)) + else if (!this.topP.equals(other.topP)) { return false; + } if (this.user == null) { return other.user == null; } - else if (!this.user.equals(other.user)) + else if (!this.user.equals(other.user)) { return false; + } if (this.proxyToolCalls == null) { return other.proxyToolCalls == null; } - else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) + else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) { return false; + } if (this.toolContext == null) { return other.toolContext == null; } - else if (!this.toolContext.equals(other.toolContext)) + else if (!this.toolContext.equals(other.toolContext)) { return false; + } return true; } + public static class Builder { + + protected MoonshotChatOptions options; + + public Builder() { + this.options = new MoonshotChatOptions(); + } + + public Builder(MoonshotChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder withN(Integer n) { + this.options.n = n; + return this; + } + + public Builder withPresencePenalty(Double presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder withFrequencyPenalty(Double frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withUser(String user) { + this.options.user = user; + return this; + } + + public Builder withTools(List tools) { + this.options.tools = tools; + return this; + } + + public Builder withToolChoice(String toolChoice) { + this.options.toolChoice = toolChoice; + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public MoonshotChatOptions build() { + return this.options; + } + + } + } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java index 0ae4fccfe..7f8a3a27b 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.aot; import org.springframework.ai.moonshot.api.MoonshotApi; diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java index 43050b025..f6eb1c476 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.api; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import java.util.function.Predicate; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; @@ -29,14 +39,6 @@ import org.springframework.util.Assert; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; -import java.util.function.Predicate; import static org.springframework.ai.moonshot.api.MoonshotConstants.DEFAULT_BASE_URL; @@ -102,6 +104,147 @@ public class MoonshotApi { this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build(); } + /** + * Creates a model response for the given chat conversation. + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code + * and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post() + .uri("/v1/chat/completions") + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the steam property to true."); + AtomicBoolean isInsideTool = new AtomicBoolean(false); + + return this.webClient.post() + .uri("/v1/chat/completions") + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + // cancels the flux stream after the "[DONE]" is received. + .takeUntil(SSE_DONE_PREDICATE) + // filters out the "[DONE]" message. + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + // Detect is the chunk is part of a streaming function call. + .map(chunk -> { + if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { + isInsideTool.set(true); + } + return chunk; + }) + // Group all chunks belonging to the same function call. + // Flux -> Flux> + .windowUntil(chunk -> { + if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + // Merging the window chunks into a single chunk. + // Reduce the inner Flux window into a single + // Mono, + // Flux> -> Flux> + .concatMapIterable(window -> { + Mono monoChunk = window.reduce( + new ChatCompletionChunk(null, null, null, null, null), + (previous, current) -> this.chunkMerger.merge(previous, current)); + return List.of(monoChunk); + }) + // Flux> -> Flux + .flatMap(mono -> mono); + } + + /** + * The reason the model stopped generating tokens. + */ + public enum ChatCompletionFinishReason { + + /** + * The model hit a natural stop point or a provided stop sequence. + */ + @JsonProperty("stop") + STOP, + /** + * The maximum number of tokens specified in the request was reached. + */ + @JsonProperty("length") + LENGTH, + /** + * The content was omitted due to a flag from our content filters. + */ + @JsonProperty("content_filter") + CONTENT_FILTER, + /** + * The model called a tool. + */ + @JsonProperty("tool_calls") + TOOL_CALLS, + /** + * (deprecated) The model called a function. + */ + @JsonProperty("function_call") + FUNCTION_CALL, + /** + * Only for compatibility with Mistral AI API. + */ + @JsonProperty("tool_call") + TOOL_CALL + + } + + /** + * Moonshot Chat Completion Models: + * + *

    + *
  • MOONSHOT_V1_8K - moonshot-v1-8k
  • + *
  • MOONSHOT_V1_32K - moonshot-v1-32k
  • + *
  • MOONSHOT_V1_128K - moonshot-v1-128k
  • + *
+ */ + public enum ChatModel implements ChatModelDescription { + + // @formatter:off + MOONSHOT_V1_8K("moonshot-v1-8k"), + MOONSHOT_V1_32K("moonshot-v1-32k"), + MOONSHOT_V1_128K("moonshot-v1-128k"); + // @formatter:on + + private final String value; + + ChatModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + @Override + public String getName() { + return this.value; + } + + } + /** * Usage statistics. * @@ -252,6 +395,7 @@ public class MoonshotApi { } } + } /** @@ -275,6 +419,16 @@ public class MoonshotApi { // @formatter:on ) { + /** + * Create a chat completion message with the given content and role. All other + * fields are null. + * @param content The contents of the message. + * @param role The role of the author of this message. + */ + public ChatCompletionMessage(Object content, Role role) { + this(content, role, null, null, null); + } + /** * Get message content as String. */ @@ -288,16 +442,6 @@ public class MoonshotApi { throw new IllegalStateException("The content is not a string!"); } - /** - * Create a chat completion message with the given content and role. All other - * fields are null. - * @param content The contents of the message. - * @param role The role of the author of this message. - */ - public ChatCompletionMessage(Object content, Role role) { - this(content, role, null, null, null); - } - /** * The role of the author of this message. NOTE: Moonshot expects the system * message to be before the user message or will fail with 400 error. @@ -340,6 +484,7 @@ public class MoonshotApi { @JsonInclude(Include.NON_NULL) public record ToolCall(@JsonProperty("id") String id, @JsonProperty("type") String type, @JsonProperty("function") ChatCompletionFunction function) { + } /** @@ -352,44 +497,8 @@ public class MoonshotApi { @JsonInclude(Include.NON_NULL) public record ChatCompletionFunction(@JsonProperty("name") String name, @JsonProperty("arguments") String arguments) { + } - } - - /** - * The reason the model stopped generating tokens. - */ - public enum ChatCompletionFinishReason { - - /** - * The model hit a natural stop point or a provided stop sequence. - */ - @JsonProperty("stop") - STOP, - /** - * The maximum number of tokens specified in the request was reached. - */ - @JsonProperty("length") - LENGTH, - /** - * The content was omitted due to a flag from our content filters. - */ - @JsonProperty("content_filter") - CONTENT_FILTER, - /** - * The model called a tool. - */ - @JsonProperty("tool_calls") - TOOL_CALLS, - /** - * (deprecated) The model called a function. - */ - @JsonProperty("function_call") - FUNCTION_CALL, - /** - * Only for compatibility with Mistral AI API. - */ - @JsonProperty("tool_call") - TOOL_CALL } @@ -431,6 +540,7 @@ public class MoonshotApi { @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) { // @formatter:on } + } /** @@ -471,39 +581,7 @@ public class MoonshotApi { @JsonProperty("usage") Usage usage // @formatter:on ) { - } - } - /** - * Moonshot Chat Completion Models: - * - *
    - *
  • MOONSHOT_V1_8K - moonshot-v1-8k
  • - *
  • MOONSHOT_V1_32K - moonshot-v1-32k
  • - *
  • MOONSHOT_V1_128K - moonshot-v1-128k
  • - *
- */ - public enum ChatModel implements ChatModelDescription { - - // @formatter:off - MOONSHOT_V1_8K("moonshot-v1-8k"), - MOONSHOT_V1_32K("moonshot-v1-32k"), - MOONSHOT_V1_128K("moonshot-v1-128k"); - // @formatter:on - - private final String value; - - ChatModel(String value) { - this.value = value; - } - - public String getValue() { - return this.value; - } - - @Override - public String getName() { - return this.value; } } @@ -564,76 +642,9 @@ public class MoonshotApi { public Function(String description, String name, String jsonSchema) { this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema)); } + } - } - /** - * Creates a model response for the given chat conversation. - * @param chatRequest The chat completion request. - * @return Entity response with {@link ChatCompletion} as a body and HTTP status code - * and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); - - return this.restClient.post() - .uri("/v1/chat/completions") - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletion.class); - } - - /** - * Creates a streaming chat response for the given chat conversation. - * @param chatRequest The chat completion request. Must have the stream property set - * to true. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(chatRequest.stream(), "Request must set the steam property to true."); - AtomicBoolean isInsideTool = new AtomicBoolean(false); - - return this.webClient.post() - .uri("/v1/chat/completions") - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - // cancels the flux stream after the "[DONE]" is received. - .takeUntil(SSE_DONE_PREDICATE) - // filters out the "[DONE]" message. - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) - // Detect is the chunk is part of a streaming function call. - .map(chunk -> { - if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { - isInsideTool.set(true); - } - return chunk; - }) - // Group all chunks belonging to the same function call. - // Flux -> Flux> - .windowUntil(chunk -> { - if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - // Merging the window chunks into a single chunk. - // Reduce the inner Flux window into a single - // Mono, - // Flux> -> Flux> - .concatMapIterable(window -> { - Mono monoChunk = window.reduce( - new ChatCompletionChunk(null, null, null, null, null), - (previous, current) -> this.chunkMerger.merge(previous, current)); - return List.of(monoChunk); - }) - // Flux> -> Flux - .flatMap(mono -> mono); } } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotConstants.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotConstants.java index c2aea6c05..3d6bdd4b2 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotConstants.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotConstants.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.api; import org.springframework.ai.observation.conventions.AiProvider; diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java index 5afff8216..06f1dc765 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java @@ -1,5 +1,24 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.moonshot.api; +import java.util.ArrayList; +import java.util.List; + import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionChunk; import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionChunk.ChunkChoice; import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionFinishReason; @@ -9,9 +28,6 @@ import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.Rol import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.ToolCall; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; -import java.util.List; - /** * Helper class to support Streaming function calling. It can merge the streamed * ChatCompletionChunk in case of function calling message. diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/metadata/MoonshotUsage.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/metadata/MoonshotUsage.java index 5d5fadb17..3fb67358a 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/metadata/MoonshotUsage.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/metadata/MoonshotUsage.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.moonshot.metadata; import org.springframework.ai.chat.metadata.Usage; @@ -11,15 +27,15 @@ public class MoonshotUsage implements Usage { private final MoonshotApi.Usage usage; - public static MoonshotUsage from(MoonshotApi.Usage usage) { - return new MoonshotUsage(usage); - } - protected MoonshotUsage(MoonshotApi.Usage usage) { Assert.notNull(usage, "Moonshot Usage must not be null"); this.usage = usage; } + public static MoonshotUsage from(MoonshotApi.Usage usage) { + return new MoonshotUsage(usage); + } + protected MoonshotApi.Usage getUsage() { return this.usage; } diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotChatCompletionRequestTest.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotChatCompletionRequestTest.java index 89751f8c0..89b478122 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotChatCompletionRequestTest.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotChatCompletionRequestTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.moonshot.api.MoonshotApi; import org.springframework.boot.test.context.SpringBootTest; @@ -34,7 +36,7 @@ public class MoonshotChatCompletionRequestTest { @Test void chatCompletionDefaultRequestTest() { - var request = chatModel.createRequest(new Prompt("test content"), false); + var request = this.chatModel.createRequest(new Prompt("test content"), false); assertThat(request.messages()).hasSize(1); assertThat(request.topP()).isEqualTo(1); @@ -46,7 +48,7 @@ public class MoonshotChatCompletionRequestTest { @Test void chatCompletionRequestWithOptionsTest() { var options = MoonshotChatOptions.builder().withTemperature(0.5).withTopP(0.8).build(); - var request = chatModel.createRequest(new Prompt("test content", options), true); + var request = this.chatModel.createRequest(new Prompt("test content", options), true); assertThat(request.messages().size()).isEqualTo(1); assertThat(request.topP()).isEqualTo(0.8); diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java index 4e1df77fa..e87a11227 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot; +import java.util.List; +import java.util.Optional; + import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.moonshot.api.MoonshotApi; import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion; @@ -35,10 +41,6 @@ import org.springframework.retry.RetryCallback; import org.springframework.retry.RetryContext; import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Optional; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -52,25 +54,6 @@ import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) public class MoonshotRetryTests { - private static class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private @Mock MoonshotApi moonshotApi; @@ -80,10 +63,10 @@ public class MoonshotRetryTests { @BeforeEach public void beforeEach() { RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); + this.retryListener = new TestRetryListener(); + retryTemplate.registerListener(this.retryListener); - chatModel = new MoonshotChatModel(moonshotApi, + this.chatModel = new MoonshotChatModel(this.moonshotApi, MoonshotChatOptions.builder() .withTemperature(0.7) .withTopP(1.0) @@ -100,24 +83,24 @@ public class MoonshotRetryTests { ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789l, "model", List.of(choice), new MoonshotApi.Usage(10, 10, 10)); - when(moonshotApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.moonshotApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); - var result = chatModel.call(new Prompt("text")); + var result = this.chatModel.call(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void moonshotChatNonTransientError() { - when(moonshotApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.moonshotApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("text"))); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); } @Test @@ -128,24 +111,43 @@ public class MoonshotRetryTests { ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789l, "model", List.of(choice)); - when(moonshotApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.moonshotApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(Flux.just(expectedChatCompletion)); - var result = chatModel.stream(new Prompt("text")); + var result = this.chatModel.stream(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void moonshotChatStreamNonTransientError() { - when(moonshotApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.moonshotApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")).collectList().block()); + assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text")).collectList().block()); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + } } diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotTestConfiguration.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotTestConfiguration.java index 11db99be9..60a910769 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotTestConfiguration.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot; import org.springframework.ai.moonshot.api.MoonshotApi; diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHintsTests.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHintsTests.java index e6015951d..60bb11f08 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHintsTests.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.aot; +import java.util.Set; + import org.junit.jupiter.api.Test; + import org.springframework.ai.moonshot.api.MoonshotApi; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; -import java.util.Set; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MockWeatherService.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MockWeatherService.java index 6c3619fdb..402409649 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MockWeatherService.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,31 +13,37 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.api; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * @author Geng Rong */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, request.unit); } /** @@ -65,28 +71,25 @@ public class MockWeatherService implements Function response = moonshotApi.chatCompletionEntity(new ChatCompletionRequest( + ResponseEntity response = this.moonshotApi.chatCompletionEntity(new ChatCompletionRequest( List.of(chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.8, false)); assertThat(response).isNotNull(); @@ -57,7 +59,7 @@ public class MoonshotApiIT { You should reply to the user's request with your name and also in the style of a pirate. """, Role.SYSTEM); - ResponseEntity response = moonshotApi.chatCompletionEntity(new ChatCompletionRequest( + ResponseEntity response = this.moonshotApi.chatCompletionEntity(new ChatCompletionRequest( List.of(systemMessage, userMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.8, false)); assertThat(response).isNotNull(); @@ -67,7 +69,7 @@ public class MoonshotApiIT { @Test void chatCompletionStream() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - Flux response = moonshotApi.chatCompletionStream(new ChatCompletionRequest( + Flux response = this.moonshotApi.chatCompletionStream(new ChatCompletionRequest( List.of(chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.8, true)); assertThat(response).isNotNull(); diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiToolFunctionCallIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiToolFunctionCallIT.java index 7c3764afe..fa2cc4863 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiToolFunctionCallIT.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiToolFunctionCallIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,12 +16,17 @@ package org.springframework.ai.moonshot.api; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion; import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage; import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.Role; @@ -32,10 +37,6 @@ import org.springframework.ai.moonshot.api.MoonshotApi.FunctionTool; import org.springframework.ai.moonshot.api.MoonshotApi.FunctionTool.Type; import org.springframework.http.ResponseEntity; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -44,12 +45,6 @@ import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "MOONSHOT_API_KEY", matches = ".+") public class MoonshotApiToolFunctionCallIT { - private final Logger logger = LoggerFactory.getLogger(MoonshotApiToolFunctionCallIT.class); - - private final MockWeatherService weatherService = new MockWeatherService(); - - private final MoonshotApi moonshotApi = new MoonshotApi(System.getenv("MOONSHOT_API_KEY")); - private static final FunctionTool FUNCTION_TOOL = new FunctionTool(Type.FUNCTION, new FunctionTool.Function( "Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """ { @@ -76,6 +71,21 @@ public class MoonshotApiToolFunctionCallIT { } """)); + private final Logger logger = LoggerFactory.getLogger(MoonshotApiToolFunctionCallIT.class); + + private final MockWeatherService weatherService = new MockWeatherService(); + + private final MoonshotApi moonshotApi = new MoonshotApi(System.getenv("MOONSHOT_API_KEY")); + + private static T fromJson(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + @SuppressWarnings("null") @Test public void toolFunctionCall() { @@ -97,7 +107,7 @@ public class MoonshotApiToolFunctionCallIT { ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(messages, MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), List.of(FUNCTION_TOOL), ToolChoiceBuilder.AUTO); - ResponseEntity chatCompletion = moonshotApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity chatCompletion = this.moonshotApi.chatCompletionEntity(chatCompletionRequest); assertThat(chatCompletion.getBody()).isNotNull(); assertThat(chatCompletion.getBody().choices()).isNotEmpty(); @@ -116,7 +126,7 @@ public class MoonshotApiToolFunctionCallIT { MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(), MockWeatherService.Request.class); - MockWeatherService.Response weatherResponse = weatherService.apply(weatherRequest); + MockWeatherService.Response weatherResponse = this.weatherService.apply(weatherRequest); // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL, @@ -127,9 +137,9 @@ public class MoonshotApiToolFunctionCallIT { var functionResponseRequest = new ChatCompletionRequest(messages, MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.5); - ResponseEntity chatCompletion2 = moonshotApi.chatCompletionEntity(functionResponseRequest); + ResponseEntity chatCompletion2 = this.moonshotApi.chatCompletionEntity(functionResponseRequest); - logger.info("Final response: " + chatCompletion2.getBody()); + this.logger.info("Final response: " + chatCompletion2.getBody()); assertThat(Objects.requireNonNull(chatCompletion2.getBody()).choices()).isNotEmpty(); @@ -138,13 +148,4 @@ public class MoonshotApiToolFunctionCallIT { .containsAnyOf("30"); } - private static T fromJson(String json, Class targetClass) { - try { - return new ObjectMapper().readValue(json, targetClass); - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - -} \ No newline at end of file +} diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/ActorsFilms.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/ActorsFilms.java index d4436cbb7..c0ec33e70 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/ActorsFilms.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/ActorsFilms.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.chat; import java.util.List; @@ -30,7 +31,7 @@ public class ActorsFilms { } public String getActor() { - return actor; + return this.actor; } public void setActor(String actor) { @@ -38,7 +39,7 @@ public class ActorsFilms { } public List getMovies() { - return movies; + return this.movies; } public void setMovies(List movies) { @@ -47,7 +48,7 @@ public class ActorsFilms { @Override public String toString() { - return "ActorsFilms{" + "actor='" + actor + '\'' + ", movies=" + movies + '}'; + return "ActorsFilms{" + "actor='" + this.actor + '\'' + ", movies=" + this.movies + '}'; } } diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java index 91a4bc9b3..8fa54687b 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.chat; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -33,12 +41,6 @@ import org.springframework.ai.moonshot.api.MockWeatherService; import org.springframework.ai.moonshot.api.MoonshotApi; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; -import reactor.core.publisher.Flux; - -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -68,7 +70,7 @@ class MoonshotChatModelFunctionCallingIT { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -92,7 +94,7 @@ class MoonshotChatModelFunctionCallingIT { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -108,4 +110,4 @@ class MoonshotChatModelFunctionCallingIT { assertThat(content).contains("30", "10", "15"); } -} \ No newline at end of file +} diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelIT.java index f0b4b7944..83222c75d 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelIT.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.chat; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -39,11 +46,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.core.io.Resource; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -68,10 +70,10 @@ public class MoonshotChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); } @@ -114,7 +116,7 @@ public class MoonshotChatModelIT { "numbers": [1, 2, 3] }""", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -133,15 +135,12 @@ public class MoonshotChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -161,7 +160,7 @@ public class MoonshotChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -184,7 +183,7 @@ public class MoonshotChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = streamingChatModel.stream(prompt) + String generationTextFromStream = this.streamingChatModel.stream(prompt) .collectList() .block() .stream() @@ -200,4 +199,8 @@ public class MoonshotChatModelIT { assertThat(actorsFilms.movies()).hasSize(5); } + record ActorsFilmsRecord(String actor, List movies) { + + } + } diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelObservationIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelObservationIT.java index 9e4239a5c..f13a46f4a 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelObservationIT.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.moonshot.chat; +import java.util.List; +import java.util.stream.Collectors; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; @@ -35,10 +41,6 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -61,7 +63,7 @@ public class MoonshotChatModelObservationIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -79,7 +81,7 @@ public class MoonshotChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -102,7 +104,7 @@ public class MoonshotChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -123,7 +125,7 @@ public class MoonshotChatModelObservationIT { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-oci-genai/pom.xml b/models/spring-ai-oci-genai/pom.xml index b4ab01cfe..64a474969 100644 --- a/models/spring-ai-oci-genai/pom.xml +++ b/models/spring-ai-oci-genai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java index 123e0705f..e3658a226 100644 --- a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.oci; import java.util.ArrayList; @@ -28,6 +29,7 @@ import com.oracle.bmc.generativeaiinference.model.OnDemandServingMode; import com.oracle.bmc.generativeaiinference.model.ServingMode; import com.oracle.bmc.generativeaiinference.requests.EmbedTextRequest; import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -83,7 +85,7 @@ public class OCIEmbeddingModel extends AbstractEmbeddingModel { @Override public EmbeddingResponse call(EmbeddingRequest request) { Assert.notEmpty(request.getInstructions(), "At least one text is required!"); - OCIEmbeddingOptions runtimeOptions = mergeOptions(request.getOptions(), options); + OCIEmbeddingOptions runtimeOptions = mergeOptions(request.getOptions(), this.options); List embedTextRequests = createRequests(request.getInstructions(), runtimeOptions); EmbeddingModelObservationContext context = EmbeddingModelObservationContext.builder() @@ -109,7 +111,7 @@ public class OCIEmbeddingModel extends AbstractEmbeddingModel { AtomicInteger index = new AtomicInteger(0); List embeddings = new ArrayList<>(); for (EmbedTextRequest embedTextRequest : embedTextRequests) { - EmbedTextResult embedTextResult = genAi.embedText(embedTextRequest).getEmbedTextResult(); + EmbedTextResult embedTextResult = this.genAi.embedText(embedTextRequest).getEmbedTextResult(); if (modelId == null) { modelId = embedTextResult.getModelId(); } diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingOptions.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingOptions.java index e72f5359f..3f5641a6e 100644 --- a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingOptions.java +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.oci; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails; + import org.springframework.ai.embedding.EmbeddingOptions; /** @@ -40,6 +42,47 @@ public class OCIEmbeddingOptions implements EmbeddingOptions { return new Builder(); } + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + /** + * Not used by OCI GenAI. + * @return null + */ + @Override + public Integer getDimensions() { + return null; + } + + public String getCompartment() { + return this.compartment; + } + + public void setCompartment(String compartment) { + this.compartment = compartment; + } + + public String getServingMode() { + return this.servingMode; + } + + public void setServingMode(String servingMode) { + this.servingMode = servingMode; + } + + public EmbedTextDetails.Truncate getTruncate() { + return this.truncate; + } + + public void setTruncate(EmbedTextDetails.Truncate truncate) { + this.truncate = truncate; + } + public static class Builder { private final OCIEmbeddingOptions options = new OCIEmbeddingOptions(); @@ -70,45 +113,4 @@ public class OCIEmbeddingOptions implements EmbeddingOptions { } - public String getModel() { - return this.model; - } - - /** - * Not used by OCI GenAI. - * @return null - */ - @Override - public Integer getDimensions() { - return null; - } - - public void setModel(String model) { - this.model = model; - } - - public String getCompartment() { - return compartment; - } - - public void setCompartment(String compartment) { - this.compartment = compartment; - } - - public String getServingMode() { - return servingMode; - } - - public void setServingMode(String servingMode) { - this.servingMode = servingMode; - } - - public EmbedTextDetails.Truncate getTruncate() { - return truncate; - } - - public void setTruncate(EmbedTextDetails.Truncate truncate) { - this.truncate = truncate; - } - } diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java index 5124bd734..b1f6da89b 100644 --- a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.oci; import java.io.IOException; diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java index 8d25240bf..586fbfdde 100644 --- a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,20 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.oci; import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.oci.BaseEmbeddingModelTest.OCI_COMPARTMENT_ID_KEY; -@EnabledIfEnvironmentVariable(named = OCI_COMPARTMENT_ID_KEY, matches = ".+") +@EnabledIfEnvironmentVariable(named = org.springframework.ai.oci.BaseEmbeddingModelTest.OCI_COMPARTMENT_ID_KEY, + matches = ".+") public class OCIEmbeddingModelIT extends BaseEmbeddingModelTest { private final OCIEmbeddingModel embeddingModel = get(); @@ -35,13 +37,13 @@ public class OCIEmbeddingModelIT extends BaseEmbeddingModelTest { @Test void embed() { - float[] embedding = embeddingModel.embed(new Document("How many provinces are in Canada?")); + float[] embedding = this.embeddingModel.embed(new Document("How many provinces are in Canada?")); assertThat(embedding).hasSize(1024); } @Test void call() { - EmbeddingResponse response = embeddingModel.call(new EmbeddingRequest(content, null)); + EmbeddingResponse response = this.embeddingModel.call(new EmbeddingRequest(this.content, null)); assertThat(response).isNotNull(); assertThat(response.getResults()).hasSize(2); assertThat(response.getMetadata().getModel()).isEqualTo(EMBEDDING_MODEL_V2); @@ -49,8 +51,8 @@ public class OCIEmbeddingModelIT extends BaseEmbeddingModelTest { @Test void callWithOptions() { - EmbeddingResponse response = embeddingModel - .call(new EmbeddingRequest(content, OCIEmbeddingOptions.builder().withModel(EMBEDDING_MODEL_V3).build())); + EmbeddingResponse response = this.embeddingModel.call(new EmbeddingRequest(this.content, + OCIEmbeddingOptions.builder().withModel(EMBEDDING_MODEL_V3).build())); assertThat(response).isNotNull(); assertThat(response.getResults()).hasSize(2); assertThat(response.getMetadata().getModel()).isEqualTo(EMBEDDING_MODEL_V3); diff --git a/models/spring-ai-ollama/pom.xml b/models/spring-ai-ollama/pom.xml index 69f64f8c9..3b9a4428a 100644 --- a/models/spring-ai-ollama/pom.xml +++ b/models/spring-ai-ollama/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 93edabe8c..c60523c00 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; import java.util.Base64; @@ -24,13 +25,19 @@ import java.util.Set; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.model.*; +import org.springframework.ai.chat.model.AbstractToolCallSupport; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; @@ -43,22 +50,20 @@ import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.ollama.api.OllamaApi; -import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaApi.ChatRequest; import org.springframework.ai.ollama.api.OllamaApi.Message.Role; import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall; import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.PullModelStrategy; import org.springframework.ai.ollama.metadata.OllamaChatUsage; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; - /** * {@link ChatModel} implementation for {@literal Ollama}. Ollama allows developers to run * large language models and generate embeddings locally. It supports open-source models @@ -96,7 +101,7 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode this.chatApi = ollamaApi; this.defaultOptions = defaultOptions; this.observationRegistry = observationRegistry; - this.modelManager = new OllamaModelManager(chatApi, modelManagementOptions); + this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions); initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy()); } @@ -104,6 +109,22 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode return new Builder(); } + public static ChatResponseMetadata from(OllamaApi.ChatResponse response) { + Assert.notNull(response, "OllamaApi.ChatResponse must not be null"); + return ChatResponseMetadata.builder() + .withUsage(OllamaChatUsage.from(response)) + .withModel(response.model()) + .withKeyValue("created-at", response.createdAt()) + .withKeyValue("eval-duration", response.evalDuration()) + .withKeyValue("eval-count", response.evalCount()) + .withKeyValue("load-duration", response.loadDuration()) + .withKeyValue("eval-duration", response.promptEvalDuration()) + .withKeyValue("eval-count", response.promptEvalCount()) + .withKeyValue("total-duration", response.totalDuration()) + .withKeyValue("done", response.done()) + .build(); + } + @Override public ChatResponse call(Prompt prompt) { @@ -157,22 +178,6 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode return response; } - public static ChatResponseMetadata from(OllamaApi.ChatResponse response) { - Assert.notNull(response, "OllamaApi.ChatResponse must not be null"); - return ChatResponseMetadata.builder() - .withUsage(OllamaChatUsage.from(response)) - .withModel(response.model()) - .withKeyValue("created-at", response.createdAt()) - .withKeyValue("eval-duration", response.evalDuration()) - .withKeyValue("eval-count", response.evalCount()) - .withKeyValue("load-duration", response.loadDuration()) - .withKeyValue("eval-duration", response.promptEvalDuration()) - .withKeyValue("eval-count", response.promptEvalCount()) - .withKeyValue("total-duration", response.totalDuration()) - .withKeyValue("done", response.done()) - .build(); - } - @Override public Flux stream(Prompt prompt) { return Flux.deferContextual(contextView -> { @@ -435,10 +440,10 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode } public OllamaChatModel build() { - return new OllamaChatModel(ollamaApi, defaultOptions, functionCallbackContext, toolFunctionCallbacks, - observationRegistry, modelManagementOptions); + return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackContext, + this.toolFunctionCallbacks, this.observationRegistry, this.modelManagementOptions); } } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java index 7034a9c03..f44c9c6ea 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; import java.time.Duration; @@ -22,19 +23,27 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.*; +import org.springframework.ai.embedding.AbstractEmbeddingModel; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.ollama.api.OllamaApi; -import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.PullModelStrategy; import org.springframework.ai.ollama.metadata.OllamaEmbeddingUsage; import org.springframework.util.Assert; @@ -236,9 +245,10 @@ public class OllamaEmbeddingModel extends AbstractEmbeddingModel { } public OllamaEmbeddingModel build() { - return new OllamaEmbeddingModel(ollamaApi, defaultOptions, observationRegistry, modelManagementOptions); + return new OllamaEmbeddingModel(this.ollamaApi, this.defaultOptions, this.observationRegistry, + this.modelManagementOptions); } } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java index 2d89a804e..bd8799c9b 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.aot; import org.springframework.ai.ollama.api.OllamaApi; diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java index acd9028d1..bbd32c511 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.api; import java.io.IOException; @@ -23,8 +24,14 @@ import java.util.Map; import java.util.Objects; import java.util.function.Consumer; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.boot.context.properties.bind.ConstructorBinding; @@ -39,13 +46,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - /** * Java Client for the Ollama API. https://ollama.ai * @@ -56,40 +56,20 @@ import reactor.core.publisher.Mono; // @formatter:off public class OllamaApi { - private static final Log logger = LogFactory.getLog(OllamaApi.class); - - private static final String DEFAULT_BASE_URL = "http://localhost:11434"; - public static final String PROVIDER_NAME = AiProvider.OLLAMA.value(); public static final String REQUEST_BODY_NULL_ERROR = "The request body can not be null."; + private static final Log logger = LogFactory.getLog(OllamaApi.class); + + private static final String DEFAULT_BASE_URL = "http://localhost:11434"; + private final ResponseErrorHandler responseErrorHandler; private final RestClient restClient; private final WebClient webClient; - private static class OllamaResponseErrorHandler implements ResponseErrorHandler { - - @Override - public boolean hasError(ClientHttpResponse response) throws IOException { - return response.getStatusCode().isError(); - } - - @Override - public void handleError(ClientHttpResponse response) throws IOException { - if (response.getStatusCode().isError()) { - int statusCode = response.getStatusCode().value(); - String statusText = response.getStatusText(); - String message = StreamUtils.copyToString(response.getBody(), java.nio.charset.StandardCharsets.UTF_8); - logger.warn(String.format("[%s] %s - %s", statusCode, statusText, message)); - throw new RuntimeException(String.format("[%s] %s - %s", statusCode, statusText, message)); - } - } - - } - /** * Default constructor that uses the default localhost url. */ @@ -125,9 +105,223 @@ public class OllamaApi { this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build(); } + /** + * Generate a completion for the given prompt. + * @param completionRequest Completion request. + * @return Completion response. + * @deprecated Use {@link #chat(ChatRequest)} instead. + */ + @Deprecated(since = "1.0.0-M2", forRemoval = true) + public GenerateResponse generate(GenerateRequest completionRequest) { + Assert.notNull(completionRequest, REQUEST_BODY_NULL_ERROR); + Assert.isTrue(completionRequest.stream() == false, "Stream mode must be disabled."); + + return this.restClient.post() + .uri("/api/generate") + .body(completionRequest) + .retrieve() + .onStatus(this.responseErrorHandler) + .body(GenerateResponse.class); + } + // -------------------------------------------------------------------------- // Generate & Streaming Generate // -------------------------------------------------------------------------- + + /** + * Generate a streaming completion for the given prompt. + * @param completionRequest Completion request. The request must set the stream + * property to true. + * @return Completion response as a {@link Flux} stream. + * @deprecated Use {@link #streamingChat(ChatRequest)} instead. + */ + @Deprecated(since = "1.0.0-M2", forRemoval = true) + public Flux generateStreaming(GenerateRequest completionRequest) { + Assert.notNull(completionRequest, REQUEST_BODY_NULL_ERROR); + Assert.isTrue(completionRequest.stream(), "Request must set the stream property to true."); + + return this.webClient.post() + .uri("/api/generate") + .body(Mono.just(completionRequest), GenerateRequest.class) + .retrieve() + .bodyToFlux(GenerateResponse.class) + .handle((data, sink) -> { + if (logger.isTraceEnabled()) { + logger.trace(data); + } + sink.next(data); + }); + } + + /** + * Generate the next message in a chat with a provided model. + * This is a streaming endpoint (controlled by the 'stream' request property), so + * there will be a series of responses. The final response object will include + * statistics and additional data from the request. + * @param chatRequest Chat request. + * @return Chat response. + */ + public ChatResponse chat(ChatRequest chatRequest) { + Assert.notNull(chatRequest, REQUEST_BODY_NULL_ERROR); + Assert.isTrue(!chatRequest.stream(), "Stream mode must be disabled."); + + return this.restClient.post() + .uri("/api/chat") + .body(chatRequest) + .retrieve() + .onStatus(this.responseErrorHandler) + .body(ChatResponse.class); + } + + /** + * Streaming response for the chat completion request. + * @param chatRequest Chat request. The request must set the stream property to true. + * @return Chat response as a {@link Flux} stream. + */ + public Flux streamingChat(ChatRequest chatRequest) { + Assert.notNull(chatRequest, REQUEST_BODY_NULL_ERROR); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + return this.webClient.post() + .uri("/api/chat") + .body(Mono.just(chatRequest), GenerateRequest.class) + .retrieve() + .bodyToFlux(ChatResponse.class) + .handle((data, sink) -> { + if (logger.isTraceEnabled()) { + logger.trace(data); + } + sink.next(data); + }); + } + + /** + * Generate embeddings from a model. + * @param embeddingsRequest Embedding request. + * @return Embeddings response. + */ + public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) { + Assert.notNull(embeddingsRequest, REQUEST_BODY_NULL_ERROR); + + return this.restClient.post() + .uri("/api/embed") + .body(embeddingsRequest) + .retrieve() + .onStatus(this.responseErrorHandler) + .body(EmbeddingsResponse.class); + } + + // -------------------------------------------------------------------------- + // Chat & Streaming Chat + // -------------------------------------------------------------------------- + + /** + * Generate embeddings from a model. + * @param embeddingRequest Embedding request. + * @return Embedding response. + * @deprecated Use {@link #embed(EmbeddingsRequest)} instead. + */ + @Deprecated(since = "1.0.0-M2", forRemoval = true) + public EmbeddingResponse embeddings(EmbeddingRequest embeddingRequest) { + Assert.notNull(embeddingRequest, REQUEST_BODY_NULL_ERROR); + + return this.restClient.post() + .uri("/api/embeddings") + .body(embeddingRequest) + .retrieve() + .onStatus(this.responseErrorHandler) + .body(EmbeddingResponse.class); + } + + /** + * List models that are available locally on the machine where Ollama is running. + */ + public ListModelResponse listModels() { + return this.restClient.get() + .uri("/api/tags") + .retrieve() + .onStatus(this.responseErrorHandler) + .body(ListModelResponse.class); + } + + /** + * Show information about a model available locally on the machine where Ollama is running. + */ + public ShowModelResponse showModel(ShowModelRequest showModelRequest) { + Assert.notNull(showModelRequest, "showModelRequest must not be null"); + return this.restClient.post() + .uri("/api/show") + .body(showModelRequest) + .retrieve() + .onStatus(this.responseErrorHandler) + .body(ShowModelResponse.class); + } + + /** + * Copy a model. Creates a model with another name from an existing model. + */ + public ResponseEntity copyModel(CopyModelRequest copyModelRequest) { + Assert.notNull(copyModelRequest, "copyModelRequest must not be null"); + return this.restClient.post() + .uri("/api/copy") + .body(copyModelRequest) + .retrieve() + .onStatus(this.responseErrorHandler) + .toBodilessEntity(); + } + + /** + * Delete a model and its data. + */ + public ResponseEntity deleteModel(DeleteModelRequest deleteModelRequest) { + Assert.notNull(deleteModelRequest, "deleteModelRequest must not be null"); + return this.restClient.method(HttpMethod.DELETE) + .uri("/api/delete") + .body(deleteModelRequest) + .retrieve() + .onStatus(this.responseErrorHandler) + .toBodilessEntity(); + } + + // -------------------------------------------------------------------------- + // Embeddings + // -------------------------------------------------------------------------- + + /** + * Download a model from the Ollama library. Cancelled pulls are resumed from where they left off, + * and multiple calls will share the same download progress. + */ + public Flux pullModel(PullModelRequest pullModelRequest) { + Assert.notNull(pullModelRequest, "pullModelRequest must not be null"); + Assert.isTrue(pullModelRequest.stream(), "Request must set the stream property to true."); + + return this.webClient.post() + .uri("/api/pull") + .bodyValue(pullModelRequest) + .retrieve() + .bodyToFlux(ProgressResponse.class); + } + + private static class OllamaResponseErrorHandler implements ResponseErrorHandler { + + @Override + public boolean hasError(ClientHttpResponse response) throws IOException { + return response.getStatusCode().isError(); + } + + @Override + public void handleError(ClientHttpResponse response) throws IOException { + if (response.getStatusCode().isError()) { + int statusCode = response.getStatusCode().value(); + String statusText = response.getStatusText(); + String message = StreamUtils.copyToString(response.getBody(), java.nio.charset.StandardCharsets.UTF_8); + logger.warn(String.format("[%s] %s - %s", statusCode, statusText, message)); + throw new RuntimeException(String.format("[%s] %s - %s", statusCode, statusText, message)); + } + } + + } + /** * The request object sent to the /generate endpoint. * @@ -197,8 +391,10 @@ public class OllamaApi { public static class Builder { - private String model; private final String prompt; + + private String model; + private String format; private Map options; private String system; @@ -269,7 +465,7 @@ public class OllamaApi { } public GenerateRequest build() { - return new GenerateRequest(model, prompt, format, options, system, template, context, stream, raw, images, keepAlive); + return new GenerateRequest(this.model, this.prompt, this.format, this.options, this.system, this.template, this.context, this.stream, this.raw, this.images, this.keepAlive); } } @@ -312,53 +508,6 @@ public class OllamaApi { @JsonProperty("eval_duration") Duration evalDuration) { } - /** - * Generate a completion for the given prompt. - * @param completionRequest Completion request. - * @return Completion response. - * @deprecated Use {@link #chat(ChatRequest)} instead. - */ - @Deprecated(since = "1.0.0-M2", forRemoval = true) - public GenerateResponse generate(GenerateRequest completionRequest) { - Assert.notNull(completionRequest, REQUEST_BODY_NULL_ERROR); - Assert.isTrue(completionRequest.stream() == false, "Stream mode must be disabled."); - - return this.restClient.post() - .uri("/api/generate") - .body(completionRequest) - .retrieve() - .onStatus(this.responseErrorHandler) - .body(GenerateResponse.class); - } - - /** - * Generate a streaming completion for the given prompt. - * @param completionRequest Completion request. The request must set the stream - * property to true. - * @return Completion response as a {@link Flux} stream. - * @deprecated Use {@link #streamingChat(ChatRequest)} instead. - */ - @Deprecated(since = "1.0.0-M2", forRemoval = true) - public Flux generateStreaming(GenerateRequest completionRequest) { - Assert.notNull(completionRequest, REQUEST_BODY_NULL_ERROR); - Assert.isTrue(completionRequest.stream(), "Request must set the stream property to true."); - - return webClient.post() - .uri("/api/generate") - .body(Mono.just(completionRequest), GenerateRequest.class) - .retrieve() - .bodyToFlux(GenerateResponse.class) - .handle((data, sink) -> { - if (logger.isTraceEnabled()) { - logger.trace(data); - } - sink.next(data); - }); - } - - // -------------------------------------------------------------------------- - // Chat & Streaming Chat - // -------------------------------------------------------------------------- /** * Chat message object. * @@ -374,6 +523,10 @@ public class OllamaApi { @JsonProperty("images") List images, @JsonProperty("tool_calls") List toolCalls) { + public static Builder builder(Role role) { + return new Builder(role); + } + /** * The role of the message in the conversation. */ @@ -420,10 +573,6 @@ public class OllamaApi { @JsonProperty("arguments") Map arguments) { } - public static Builder builder(Role role) { - return new Builder(role); - } - public static class Builder { private final Role role; @@ -451,7 +600,7 @@ public class OllamaApi { } public Message build() { - return new Message(role, content, images, toolCalls); + return new Message(this.role, this.content, this.images, this.toolCalls); } } @@ -486,6 +635,10 @@ public class OllamaApi { @JsonProperty("options") Map options ) { + public static Builder builder(String model) { + return new Builder(model); + } + /** * Represents a tool the model may call. Currently, only functions are supported as a tool. * @@ -543,10 +696,6 @@ public class OllamaApi { } } } - - public static Builder builder(String model) { - return new Builder(model); - } public static class Builder { @@ -602,11 +751,15 @@ public class OllamaApi { } public ChatRequest build() { - return new ChatRequest(model, messages, stream, format, keepAlive, tools, options); + return new ChatRequest(this.model, this.messages, this.stream, this.format, this.keepAlive, this.tools, this.options); } } } + // -------------------------------------------------------------------------- + // Models + // -------------------------------------------------------------------------- + /** * Ollama chat response object. * @@ -647,51 +800,6 @@ public class OllamaApi { ) { } - /** - * Generate the next message in a chat with a provided model. - * This is a streaming endpoint (controlled by the 'stream' request property), so - * there will be a series of responses. The final response object will include - * statistics and additional data from the request. - * @param chatRequest Chat request. - * @return Chat response. - */ - public ChatResponse chat(ChatRequest chatRequest) { - Assert.notNull(chatRequest, REQUEST_BODY_NULL_ERROR); - Assert.isTrue(!chatRequest.stream(), "Stream mode must be disabled."); - - return this.restClient.post() - .uri("/api/chat") - .body(chatRequest) - .retrieve() - .onStatus(this.responseErrorHandler) - .body(ChatResponse.class); - } - - /** - * Streaming response for the chat completion request. - * @param chatRequest Chat request. The request must set the stream property to true. - * @return Chat response as a {@link Flux} stream. - */ - public Flux streamingChat(ChatRequest chatRequest) { - Assert.notNull(chatRequest, REQUEST_BODY_NULL_ERROR); - Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); - - return webClient.post() - .uri("/api/chat") - .body(Mono.just(chatRequest), GenerateRequest.class) - .retrieve() - .bodyToFlux(ChatResponse.class) - .handle((data, sink) -> { - if (logger.isTraceEnabled()) { - logger.trace(data); - } - sink.next(data); - }); - } - - // -------------------------------------------------------------------------- - // Embeddings - // -------------------------------------------------------------------------- /** * Generate embeddings from a model. * @@ -718,7 +826,7 @@ public class OllamaApi { public EmbeddingsRequest(String model, String input) { this(model, List.of(input), null, null, null); } - } + } /** * Generate embeddings from a model. @@ -759,11 +867,10 @@ public class OllamaApi { @JsonProperty("embedding") List embedding) { } - /** * The response object returned from the /embedding endpoint. * @param model The model used for generating the embeddings. - * @param embeddings The list of embeddings generated from the model. + * @param embeddings The list of embeddings generated from the model. * Each embedding (list of doubles) corresponds to a single input text. */ @JsonInclude(Include.NON_NULL) @@ -773,46 +880,9 @@ public class OllamaApi { @JsonProperty("total_duration") Long totalDuration, @JsonProperty("load_duration") Long loadDuration, @JsonProperty("prompt_eval_count") Integer promptEvalCount) { - + } - /** - * Generate embeddings from a model. - * @param embeddingsRequest Embedding request. - * @return Embeddings response. - */ - public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) { - Assert.notNull(embeddingsRequest, REQUEST_BODY_NULL_ERROR); - - return this.restClient.post() - .uri("/api/embed") - .body(embeddingsRequest) - .retrieve() - .onStatus(this.responseErrorHandler) - .body(EmbeddingsResponse.class); - } - /** - * Generate embeddings from a model. - * @param embeddingRequest Embedding request. - * @return Embedding response. - * @deprecated Use {@link #embed(EmbeddingsRequest)} instead. - */ - @Deprecated(since = "1.0.0-M2", forRemoval = true) - public EmbeddingResponse embeddings(EmbeddingRequest embeddingRequest) { - Assert.notNull(embeddingRequest, REQUEST_BODY_NULL_ERROR); - - return this.restClient.post() - .uri("/api/embeddings") - .body(embeddingRequest) - .retrieve() - .onStatus(this.responseErrorHandler) - .body(EmbeddingResponse.class); - } - - // -------------------------------------------------------------------------- - // Models - // -------------------------------------------------------------------------- - @JsonInclude(Include.NON_NULL) public record Model( @JsonProperty("name") String name, @@ -838,17 +908,6 @@ public class OllamaApi { @JsonProperty("models") List models ) {} - /** - * List models that are available locally on the machine where Ollama is running. - */ - public ListModelResponse listModels() { - return this.restClient.get() - .uri("/api/tags") - .retrieve() - .onStatus(this.responseErrorHandler) - .body(ListModelResponse.class); - } - @JsonInclude(Include.NON_NULL) public record ShowModelRequest( @JsonProperty("model") String model, @@ -875,56 +934,17 @@ public class OllamaApi { @JsonProperty("modified_at") Instant modifiedAt ) {} - /** - * Show information about a model available locally on the machine where Ollama is running. - */ - public ShowModelResponse showModel(ShowModelRequest showModelRequest) { - Assert.notNull(showModelRequest, "showModelRequest must not be null"); - return this.restClient.post() - .uri("/api/show") - .body(showModelRequest) - .retrieve() - .onStatus(this.responseErrorHandler) - .body(ShowModelResponse.class); - } - @JsonInclude(Include.NON_NULL) public record CopyModelRequest( @JsonProperty("source") String source, @JsonProperty("destination") String destination ) {} - /** - * Copy a model. Creates a model with another name from an existing model. - */ - public ResponseEntity copyModel(CopyModelRequest copyModelRequest) { - Assert.notNull(copyModelRequest, "copyModelRequest must not be null"); - return this.restClient.post() - .uri("/api/copy") - .body(copyModelRequest) - .retrieve() - .onStatus(this.responseErrorHandler) - .toBodilessEntity(); - } - @JsonInclude(Include.NON_NULL) public record DeleteModelRequest( @JsonProperty("model") String model ) {} - /** - * Delete a model and its data. - */ - public ResponseEntity deleteModel(DeleteModelRequest deleteModelRequest) { - Assert.notNull(deleteModelRequest, "deleteModelRequest must not be null"); - return this.restClient.method(HttpMethod.DELETE) - .uri("/api/delete") - .body(deleteModelRequest) - .retrieve() - .onStatus(this.responseErrorHandler) - .toBodilessEntity(); - } - @JsonInclude(Include.NON_NULL) public record PullModelRequest( @JsonProperty("model") String model, @@ -953,20 +973,5 @@ public class OllamaApi { @JsonProperty("completed") Long completed ) {} - /** - * Download a model from the Ollama library. Cancelled pulls are resumed from where they left off, - * and multiple calls will share the same download progress. - */ - public Flux pullModel(PullModelRequest pullModelRequest) { - Assert.notNull(pullModelRequest, "pullModelRequest must not be null"); - Assert.isTrue(pullModelRequest.stream(), "Request must set the stream property to true."); - - return this.webClient.post() - .uri("/api/pull") - .bodyValue(pullModelRequest) - .retrieve() - .bodyToFlux(ProgressResponse.class); - } - } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java index a70765249..3419ee283 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.api; import org.springframework.ai.model.ChatModelDescription; diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index a0dad31a4..034a4b75c 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.api; import java.util.ArrayList; @@ -23,6 +24,11 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.model.ModelOptionsUtils; @@ -31,11 +37,6 @@ import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.util.Assert; -import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; - /** * Helper class for creating strongly-typed Ollama options. * @@ -306,10 +307,70 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed return new OllamaOptions(); } + /** + * Helper factory method to create a new {@link OllamaOptions} instance. + * @return A new {@link OllamaOptions} instance. + */ + public static OllamaOptions create() { + return new OllamaOptions(); + } + + /** + * Filter out the non-supported fields from the options. + * @param options The options to filter. + * @return The filtered options. + */ + public static Map filterNonSupportedFields(Map options) { + return options.entrySet().stream() + .filter(e -> !NON_SUPPORTED_FIELDS.contains(e.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + public static OllamaOptions fromOptions(OllamaOptions fromOptions) { + return new OllamaOptions() + .withModel(fromOptions.getModel()) + .withFormat(fromOptions.getFormat()) + .withKeepAlive(fromOptions.getKeepAlive()) + .withTruncate(fromOptions.getTruncate()) + .withUseNUMA(fromOptions.getUseNUMA()) + .withNumCtx(fromOptions.getNumCtx()) + .withNumBatch(fromOptions.getNumBatch()) + .withNumGPU(fromOptions.getNumGPU()) + .withMainGPU(fromOptions.getMainGPU()) + .withLowVRAM(fromOptions.getLowVRAM()) + .withF16KV(fromOptions.getF16KV()) + .withLogitsAll(fromOptions.getLogitsAll()) + .withVocabOnly(fromOptions.getVocabOnly()) + .withUseMMap(fromOptions.getUseMMap()) + .withUseMLock(fromOptions.getUseMLock()) + .withNumThread(fromOptions.getNumThread()) + .withNumKeep(fromOptions.getNumKeep()) + .withSeed(fromOptions.getSeed()) + .withNumPredict(fromOptions.getNumPredict()) + .withTopK(fromOptions.getTopK()) + .withTopP(fromOptions.getTopP()) + .withTfsZ(fromOptions.getTfsZ()) + .withTypicalP(fromOptions.getTypicalP()) + .withRepeatLastN(fromOptions.getRepeatLastN()) + .withTemperature(fromOptions.getTemperature()) + .withRepeatPenalty(fromOptions.getRepeatPenalty()) + .withPresencePenalty(fromOptions.getPresencePenalty()) + .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) + .withMirostat(fromOptions.getMirostat()) + .withMirostatTau(fromOptions.getMirostatTau()) + .withMirostatEta(fromOptions.getMirostatEta()) + .withPenalizeNewline(fromOptions.getPenalizeNewline()) + .withStop(fromOptions.getStop()) + .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withToolContext(fromOptions.getToolContext()); + } + public OllamaOptions build() { return this; } - + /** * @param model The ollama model names to use. See the {@link OllamaModel} for the common models. */ @@ -510,7 +571,7 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed } else { this.toolContext.putAll(toolContext); - } + } return this; } @@ -519,7 +580,7 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed // ------------------- @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -811,7 +872,7 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed @Override public List getFunctionCallbacks() { - return this.functionCallbacks; + return this.functionCallbacks; } @Override @@ -862,107 +923,51 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed return ModelOptionsUtils.objectToMap(this); } - /** - * Helper factory method to create a new {@link OllamaOptions} instance. - * @return A new {@link OllamaOptions} instance. - */ - public static OllamaOptions create() { - return new OllamaOptions(); - } - - /** - * Filter out the non-supported fields from the options. - * @param options The options to filter. - * @return The filtered options. - */ - public static Map filterNonSupportedFields(Map options) { - return options.entrySet().stream() - .filter(e -> !NON_SUPPORTED_FIELDS.contains(e.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - } - @Override public OllamaOptions copy() { return fromOptions(this); } - - public static OllamaOptions fromOptions(OllamaOptions fromOptions) { - return new OllamaOptions() - .withModel(fromOptions.getModel()) - .withFormat(fromOptions.getFormat()) - .withKeepAlive(fromOptions.getKeepAlive()) - .withTruncate(fromOptions.getTruncate()) - .withUseNUMA(fromOptions.getUseNUMA()) - .withNumCtx(fromOptions.getNumCtx()) - .withNumBatch(fromOptions.getNumBatch()) - .withNumGPU(fromOptions.getNumGPU()) - .withMainGPU(fromOptions.getMainGPU()) - .withLowVRAM(fromOptions.getLowVRAM()) - .withF16KV(fromOptions.getF16KV()) - .withLogitsAll(fromOptions.getLogitsAll()) - .withVocabOnly(fromOptions.getVocabOnly()) - .withUseMMap(fromOptions.getUseMMap()) - .withUseMLock(fromOptions.getUseMLock()) - .withNumThread(fromOptions.getNumThread()) - .withNumKeep(fromOptions.getNumKeep()) - .withSeed(fromOptions.getSeed()) - .withNumPredict(fromOptions.getNumPredict()) - .withTopK(fromOptions.getTopK()) - .withTopP(fromOptions.getTopP()) - .withTfsZ(fromOptions.getTfsZ()) - .withTypicalP(fromOptions.getTypicalP()) - .withRepeatLastN(fromOptions.getRepeatLastN()) - .withTemperature(fromOptions.getTemperature()) - .withRepeatPenalty(fromOptions.getRepeatPenalty()) - .withPresencePenalty(fromOptions.getPresencePenalty()) - .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) - .withMirostat(fromOptions.getMirostat()) - .withMirostatTau(fromOptions.getMirostatTau()) - .withMirostatEta(fromOptions.getMirostatEta()) - .withPenalizeNewline(fromOptions.getPenalizeNewline()) - .withStop(fromOptions.getStop()) - .withFunctions(fromOptions.getFunctions()) - .withProxyToolCalls(fromOptions.getProxyToolCalls()) - .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) - .withToolContext(fromOptions.getToolContext()); - } // @formatter:on @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || getClass() != o.getClass()) { return false; + } OllamaOptions that = (OllamaOptions) o; - return Objects.equals(model, that.model) && Objects.equals(format, that.format) - && Objects.equals(keepAlive, that.keepAlive) && Objects.equals(truncate, that.truncate) - && Objects.equals(useNUMA, that.useNUMA) && Objects.equals(numCtx, that.numCtx) - && Objects.equals(numBatch, that.numBatch) && Objects.equals(numGPU, that.numGPU) - && Objects.equals(mainGPU, that.mainGPU) && Objects.equals(lowVRAM, that.lowVRAM) - && Objects.equals(f16KV, that.f16KV) && Objects.equals(logitsAll, that.logitsAll) - && Objects.equals(vocabOnly, that.vocabOnly) && Objects.equals(useMMap, that.useMMap) - && Objects.equals(useMLock, that.useMLock) && Objects.equals(numThread, that.numThread) - && Objects.equals(numKeep, that.numKeep) && Objects.equals(seed, that.seed) - && Objects.equals(numPredict, that.numPredict) && Objects.equals(topK, that.topK) - && Objects.equals(topP, that.topP) && Objects.equals(tfsZ, that.tfsZ) - && Objects.equals(typicalP, that.typicalP) && Objects.equals(repeatLastN, that.repeatLastN) - && Objects.equals(temperature, that.temperature) && Objects.equals(repeatPenalty, that.repeatPenalty) - && Objects.equals(presencePenalty, that.presencePenalty) - && Objects.equals(frequencyPenalty, that.frequencyPenalty) && Objects.equals(mirostat, that.mirostat) - && Objects.equals(mirostatTau, that.mirostatTau) && Objects.equals(mirostatEta, that.mirostatEta) - && Objects.equals(penalizeNewline, that.penalizeNewline) && Objects.equals(stop, that.stop) - && Objects.equals(functionCallbacks, that.functionCallbacks) - && Objects.equals(proxyToolCalls, that.proxyToolCalls) && Objects.equals(functions, that.functions) - && Objects.equals(toolContext, that.toolContext); + return Objects.equals(this.model, that.model) && Objects.equals(this.format, that.format) + && Objects.equals(this.keepAlive, that.keepAlive) && Objects.equals(this.truncate, that.truncate) + && Objects.equals(this.useNUMA, that.useNUMA) && Objects.equals(this.numCtx, that.numCtx) + && Objects.equals(this.numBatch, that.numBatch) && Objects.equals(this.numGPU, that.numGPU) + && Objects.equals(this.mainGPU, that.mainGPU) && Objects.equals(this.lowVRAM, that.lowVRAM) + && Objects.equals(this.f16KV, that.f16KV) && Objects.equals(this.logitsAll, that.logitsAll) + && Objects.equals(this.vocabOnly, that.vocabOnly) && Objects.equals(this.useMMap, that.useMMap) + && Objects.equals(this.useMLock, that.useMLock) && Objects.equals(this.numThread, that.numThread) + && Objects.equals(this.numKeep, that.numKeep) && Objects.equals(this.seed, that.seed) + && Objects.equals(this.numPredict, that.numPredict) && Objects.equals(this.topK, that.topK) + && Objects.equals(this.topP, that.topP) && Objects.equals(this.tfsZ, that.tfsZ) + && Objects.equals(this.typicalP, that.typicalP) && Objects.equals(this.repeatLastN, that.repeatLastN) + && Objects.equals(this.temperature, that.temperature) + && Objects.equals(this.repeatPenalty, that.repeatPenalty) + && Objects.equals(this.presencePenalty, that.presencePenalty) + && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) + && Objects.equals(this.mirostat, that.mirostat) && Objects.equals(this.mirostatTau, that.mirostatTau) + && Objects.equals(this.mirostatEta, that.mirostatEta) + && Objects.equals(this.penalizeNewline, that.penalizeNewline) && Objects.equals(this.stop, that.stop) + && Objects.equals(this.functionCallbacks, that.functionCallbacks) + && Objects.equals(this.proxyToolCalls, that.proxyToolCalls) + && Objects.equals(this.functions, that.functions) && Objects.equals(this.toolContext, that.toolContext); } @Override public int hashCode() { return Objects.hash(this.model, this.format, this.keepAlive, this.truncate, this.useNUMA, this.numCtx, - this.numBatch, this.numGPU, this.mainGPU, lowVRAM, this.f16KV, this.logitsAll, this.vocabOnly, + this.numBatch, this.numGPU, this.mainGPU, this.lowVRAM, this.f16KV, this.logitsAll, this.vocabOnly, this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK, - this.topP, tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, + this.topP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta, this.penalizeNewline, this.stop, this.functionCallbacks, this.functions, this.proxyToolCalls, this.toolContext); diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/ModelManagementOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/ModelManagementOptions.java index 5d600b14e..f850cd579 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/ModelManagementOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/ModelManagementOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.management; import java.time.Duration; @@ -66,7 +67,8 @@ public record ModelManagementOptions(PullModelStrategy pullModelStrategy, List m.name().equals(normalizedModelName)); @@ -79,17 +81,17 @@ public class OllamaModelManager { } public void deleteModel(String modelName) { - logger.info("Start deletion of model: {}", modelName); + this.logger.info("Start deletion of model: {}", modelName); if (!isModelAvailable(modelName)) { - logger.info("Model {} not found", modelName); + this.logger.info("Model {} not found", modelName); return; } this.ollamaApi.deleteModel(new DeleteModelRequest(modelName)); - logger.info("Completed deletion of model: {}", modelName); + this.logger.info("Completed deletion of model: {}", modelName); } public void pullModel(String modelName) { - pullModel(modelName, options.pullModelStrategy()); + pullModel(modelName, this.options.pullModelStrategy()); } public void pullModel(String modelName, PullModelStrategy pullModelStrategy) { @@ -99,27 +101,27 @@ public class OllamaModelManager { if (PullModelStrategy.WHEN_MISSING.equals(pullModelStrategy)) { if (isModelAvailable(modelName)) { - logger.debug("Model '{}' already available. Skipping pull operation.", modelName); + this.logger.debug("Model '{}' already available. Skipping pull operation.", modelName); return; } } // @formatter:off - logger.info("Start pulling model: {}", modelName); + this.logger.info("Start pulling model: {}", modelName); this.ollamaApi.pullModel(new PullModelRequest(modelName)) .bufferUntilChanged(OllamaApi.ProgressResponse::status) .doOnEach(signal -> { var progressResponses = signal.get(); if (!CollectionUtils.isEmpty(progressResponses) && progressResponses.get(progressResponses.size() - 1) != null) { - logger.info("Pulling the '{}' model - Status: {}", modelName, progressResponses.get(progressResponses.size() - 1).status()); + this.logger.info("Pulling the '{}' model - Status: {}", modelName, progressResponses.get(progressResponses.size() - 1).status()); } }) .takeUntil(progressResponses -> progressResponses.get(0) != null && progressResponses.get(0).status().equals("success")) - .timeout(options.timeout()) - .retryWhen(Retry.backoff(options.maxRetries(), Duration.ofSeconds(5))) + .timeout(this.options.timeout()) + .retryWhen(Retry.backoff(this.options.maxRetries(), Duration.ofSeconds(5))) .blockLast(); - logger.info("Completed pulling the '{}' model", modelName); + this.logger.info("Completed pulling the '{}' model", modelName); // @formatter:on } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/PullModelStrategy.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/PullModelStrategy.java index 11be453aa..e6f021008 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/PullModelStrategy.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/PullModelStrategy.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.management; /** diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/package-info.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/package-info.java index dc7eed369..0a76de5a3 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/package-info.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatUsage.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatUsage.java index e1c1bfac8..3ccf39b4c 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatUsage.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.metadata; import java.util.Optional; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.util.Assert; @@ -30,25 +32,25 @@ public class OllamaChatUsage implements Usage { protected static final String AI_USAGE_STRING = "{ promptTokens: %1$d, generationTokens: %2$d, totalTokens: %3$d }"; - public static OllamaChatUsage from(OllamaApi.ChatResponse response) { - Assert.notNull(response, "OllamaApi.ChatResponse must not be null"); - return new OllamaChatUsage(response); - } - private final OllamaApi.ChatResponse response; public OllamaChatUsage(OllamaApi.ChatResponse response) { this.response = response; } + public static OllamaChatUsage from(OllamaApi.ChatResponse response) { + Assert.notNull(response, "OllamaApi.ChatResponse must not be null"); + return new OllamaChatUsage(response); + } + @Override public Long getPromptTokens() { - return Optional.ofNullable(response.promptEvalCount()).map(Integer::longValue).orElse(0L); + return Optional.ofNullable(this.response.promptEvalCount()).map(Integer::longValue).orElse(0L); } @Override public Long getGenerationTokens() { - return Optional.ofNullable(response.evalCount()).map(Integer::longValue).orElse(0L); + return Optional.ofNullable(this.response.evalCount()).map(Integer::longValue).orElse(0L); } @Override diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaEmbeddingUsage.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaEmbeddingUsage.java index 61ea60b33..c75ebaac1 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaEmbeddingUsage.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaEmbeddingUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.metadata; import java.util.Optional; @@ -31,17 +32,17 @@ public class OllamaEmbeddingUsage implements Usage { protected static final String AI_USAGE_STRING = "{ promptTokens: %1$d, generationTokens: %2$d, totalTokens: %3$d }"; - public static OllamaEmbeddingUsage from(EmbeddingsResponse response) { - Assert.notNull(response, "OllamaApi.EmbeddingsResponse must not be null"); - return new OllamaEmbeddingUsage(response); - } - private Long promptTokens; public OllamaEmbeddingUsage(EmbeddingsResponse response) { this.promptTokens = Optional.ofNullable(response.promptEvalCount()).map(Integer::longValue).orElse(0L); } + public static OllamaEmbeddingUsage from(EmbeddingsResponse response) { + Assert.notNull(response, "OllamaApi.EmbeddingsResponse must not be null"); + return new OllamaEmbeddingUsage(response); + } + @Override public Long getPromptTokens() { return this.promptTokens; diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java index a8845c0fd..f58413f26 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java @@ -2,12 +2,13 @@ package org.springframework.ai.ollama; import java.time.Duration; +import org.testcontainers.ollama.OllamaContainer; + import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; import org.springframework.util.StringUtils; -import org.testcontainers.ollama.OllamaContainer; public class BaseOllamaIT { diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java index ae9d51fac..5738c337f 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Testcontainers; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -29,19 +37,12 @@ import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.ollama.api.OllamaApi; -import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.api.tool.MockWeatherService; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.testcontainers.junit.jupiter.Testcontainers; -import reactor.core.publisher.Flux; - -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -74,7 +75,7 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -99,7 +100,7 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -132,4 +133,4 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT { } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java index 4b2fac29e..88db3b96c 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -33,20 +40,15 @@ import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.PullModelStrategy; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.core.convert.support.DefaultConversionService; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -67,10 +69,10 @@ class OllamaChatModelIT extends BaseOllamaIT { @Test void autoPullModelTest() { - var modelManager = new OllamaModelManager(ollamaApi); + var modelManager = new OllamaModelManager(this.ollamaApi); assertThat(modelManager.isModelAvailable(ADDITIONAL_MODEL)).isTrue(); - String joke = ChatClient.create(chatModel) + String joke = ChatClient.create(this.chatModel) .prompt("Tell me a joke") .options(OllamaOptions.builder().withModel(ADDITIONAL_MODEL).build()) .call() @@ -97,13 +99,13 @@ class OllamaChatModelIT extends BaseOllamaIT { Prompt prompt = new Prompt(List.of(systemMessage, userMessage), portableOptions); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); // ollama specific options var ollamaOptions = new OllamaOptions().withLowVRAM(true); - response = chatModel.call(new Prompt(List.of(systemMessage, userMessage), ollamaOptions)); + response = this.chatModel.call(new Prompt(List.of(systemMessage, userMessage), ollamaOptions)); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -121,12 +123,12 @@ class OllamaChatModelIT extends BaseOllamaIT { Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard"); var promptWithMessageHistory = new Prompt(List.of(new UserMessage("Hello"), response.getResult().getOutput(), new UserMessage("Tell me just the names of those pirates."))); - response = chatModel.call(promptWithMessageHistory); + response = this.chatModel.call(promptWithMessageHistory); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard"); } @@ -134,7 +136,7 @@ class OllamaChatModelIT extends BaseOllamaIT { @Test void usageTest() { Prompt prompt = new Prompt("Tell me a joke"); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); Usage usage = response.getMetadata().getUsage(); assertThat(usage).isNotNull(); @@ -175,7 +177,7 @@ class OllamaChatModelIT extends BaseOllamaIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result).isNotNull(); @@ -184,9 +186,6 @@ class OllamaChatModelIT extends BaseOllamaIT { assertThat((String) result.get("B")).containsIgnoringCase("blue"); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); @@ -198,7 +197,7 @@ class OllamaChatModelIT extends BaseOllamaIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); @@ -217,7 +216,7 @@ class OllamaChatModelIT extends BaseOllamaIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -233,6 +232,10 @@ class OllamaChatModelIT extends BaseOllamaIT { assertThat(actorsFilms.movies()).hasSize(5); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { @@ -255,4 +258,4 @@ class OllamaChatModelIT extends BaseOllamaIT { } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java index 5d8956552..0cb22784c 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; +import java.util.List; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Media; import org.springframework.ai.ollama.api.OllamaApi; -import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; @@ -31,9 +35,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.core.io.ClassPathResource; import org.springframework.util.MimeTypeUtils; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.util.List; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertThrows; @@ -57,7 +58,7 @@ class OllamaChatModelMultimodalIT extends BaseOllamaIT { var userMessage = new UserMessage("Explain what do you see in this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt(List.of(userMessage)))); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt(List.of(userMessage)))); } @Test @@ -67,7 +68,7 @@ class OllamaChatModelMultimodalIT extends BaseOllamaIT { var userMessage = new UserMessage("Explain what do you see in this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel.call(new Prompt(List.of(userMessage))); + var response = this.chatModel.call(new Prompt(List.of(userMessage))); logger.info(response.getResult().getOutput().getContent()); assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple"); @@ -91,4 +92,4 @@ class OllamaChatModelMultimodalIT extends BaseOllamaIT { } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java index 4e9f80313..5254f3e01 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; +import java.util.List; +import java.util.stream.Collectors; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; @@ -33,10 +39,6 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -61,7 +63,7 @@ public class OllamaChatModelObservationIT extends BaseOllamaIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -79,7 +81,7 @@ public class OllamaChatModelObservationIT extends BaseOllamaIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -103,7 +105,7 @@ public class OllamaChatModelObservationIT extends BaseOllamaIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -124,7 +126,7 @@ public class OllamaChatModelObservationIT extends BaseOllamaIT { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index cdb829953..5e8e74107 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; @@ -39,7 +41,7 @@ public class OllamaChatRequestTests { @Test public void createRequestWithDefaultOptions() { - var request = chatModel.ollamaChatRequest(new Prompt("Test message content"), false); + var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content"), false); assertThat(request.messages()).hasSize(1); assertThat(request.stream()).isFalse(); @@ -57,7 +59,7 @@ public class OllamaChatRequestTests { // Runtime options should override the default options. OllamaOptions promptOptions = new OllamaOptions().withTemperature(0.8).withTopP(0.5).withNumGPU(2); - var request = chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true); + var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true); assertThat(request.messages()).hasSize(1); assertThat(request.stream()).isTrue(); @@ -65,11 +67,11 @@ public class OllamaChatRequestTests { assertThat(request.model()).isEqualTo("MODEL_NAME"); assertThat(request.options().get("temperature")).isEqualTo(0.8); assertThat(request.options().get("top_k")).isEqualTo(99); // still the default - // value. + // value. assertThat(request.options().get("num_gpu")).isEqualTo(2); assertThat(request.options().get("top_p")).isEqualTo(0.5); // new field introduced - // by the - // promptOptions. + // by the + // promptOptions. } @Test @@ -82,7 +84,7 @@ public class OllamaChatRequestTests { .withTopP(0.6) .build(); - var request = chatModel.ollamaChatRequest(new Prompt("Test message content", portablePromptOptions), true); + var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", portablePromptOptions), true); assertThat(request.messages()).hasSize(1); assertThat(request.stream()).isTrue(); @@ -100,7 +102,7 @@ public class OllamaChatRequestTests { // Ollama runtime options. OllamaOptions promptOptions = new OllamaOptions().withModel("PROMPT_MODEL"); - var request = chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true); + var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true); assertThat(request.model()).isEqualTo("PROMPT_MODEL"); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java index ae0612ac3..00322204e 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,25 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; +import java.util.List; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.PullModelStrategy; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.util.List; import static org.assertj.core.api.Assertions.assertThat; @@ -52,8 +54,8 @@ class OllamaEmbeddingModelIT extends BaseOllamaIT { @Test void embeddings() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest( + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest( List.of("Hello World", "Something else"), OllamaOptions.builder().withTruncate(false).build())); assertThat(embeddingResponse.getResults()).hasSize(2); assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); @@ -64,18 +66,18 @@ class OllamaEmbeddingModelIT extends BaseOllamaIT { assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4); - assertThat(embeddingModel.dimensions()).isEqualTo(768); + assertThat(this.embeddingModel.dimensions()).isEqualTo(768); } @Test void autoPullModelAtStartupTime() { var model = "all-minilm"; - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); - var modelManager = new OllamaModelManager(ollamaApi); + var modelManager = new OllamaModelManager(this.ollamaApi); assertThat(modelManager.isModelAvailable(ADDITIONAL_MODEL)).isTrue(); - EmbeddingResponse embeddingResponse = embeddingModel + EmbeddingResponse embeddingResponse = this.embeddingModel .call(new EmbeddingRequest(List.of("Hello World", "Something else"), OllamaOptions.builder().withModel(model).withTruncate(false).build())); @@ -88,7 +90,7 @@ class OllamaEmbeddingModelIT extends BaseOllamaIT { assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4); - assertThat(embeddingModel.dimensions()).isEqualTo(768); + assertThat(this.embeddingModel.dimensions()).isEqualTo(768); modelManager.deleteModel(ADDITIONAL_MODEL); } @@ -115,4 +117,4 @@ class OllamaEmbeddingModelIT extends BaseOllamaIT { } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java index aaf786ff2..ad3ebc5a7 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; +import java.util.List; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; - import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; + import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; @@ -36,8 +39,6 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -63,13 +64,13 @@ public class OllamaEmbeddingModelObservationIT extends BaseOllamaIT { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java index b77be0a60..0afb8c247 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java @@ -1,26 +1,32 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.ollama; +import java.time.Duration; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -30,10 +36,6 @@ import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsRequest; import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse; import org.springframework.ai.ollama.api.OllamaOptions; -import java.time.Duration; -import java.util.List; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.when; @@ -54,7 +56,7 @@ public class OllamaEmbeddingModelTests { @Test public void options() { - when(ollamaApi.embed(embeddingsRequestCaptor.capture())) + when(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture())) .thenReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME", List.of(new float[] { 1f, 2f, 3f }, new float[] { 4f, 5f, 6f }), 0L, 0L, 0)) .thenReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME2", @@ -64,7 +66,7 @@ public class OllamaEmbeddingModelTests { var defaultOptions = OllamaOptions.builder().withModel("DEFAULT_MODEL").build(); var embeddingModel = OllamaEmbeddingModel.builder() - .withOllamaApi(ollamaApi) + .withOllamaApi(this.ollamaApi) .withDefaultOptions(defaultOptions) .build(); @@ -80,11 +82,11 @@ public class OllamaEmbeddingModelTests { assertThat(response.getResults().get(1).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY); assertThat(response.getMetadata().getModel()).isEqualTo("RESPONSE_MODEL_NAME"); - assertThat(embeddingsRequestCaptor.getValue().keepAlive()).isNull(); - assertThat(embeddingsRequestCaptor.getValue().truncate()).isNull(); - assertThat(embeddingsRequestCaptor.getValue().input()).isEqualTo(List.of("Input1", "Input2", "Input3")); - assertThat(embeddingsRequestCaptor.getValue().options()).isEqualTo(Map.of()); - assertThat(embeddingsRequestCaptor.getValue().model()).isEqualTo("DEFAULT_MODEL"); + assertThat(this.embeddingsRequestCaptor.getValue().keepAlive()).isNull(); + assertThat(this.embeddingsRequestCaptor.getValue().truncate()).isNull(); + assertThat(this.embeddingsRequestCaptor.getValue().input()).isEqualTo(List.of("Input1", "Input2", "Input3")); + assertThat(this.embeddingsRequestCaptor.getValue().options()).isEqualTo(Map.of()); + assertThat(this.embeddingsRequestCaptor.getValue().model()).isEqualTo("DEFAULT_MODEL"); // Tests runtime options var runtimeOptions = OllamaOptions.builder() @@ -105,11 +107,11 @@ public class OllamaEmbeddingModelTests { assertThat(response.getResults().get(1).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY); assertThat(response.getMetadata().getModel()).isEqualTo("RESPONSE_MODEL_NAME2"); - assertThat(embeddingsRequestCaptor.getValue().keepAlive()).isEqualTo(Duration.ofMinutes(10)); - assertThat(embeddingsRequestCaptor.getValue().truncate()).isFalse(); - assertThat(embeddingsRequestCaptor.getValue().input()).isEqualTo(List.of("Input4", "Input5", "Input6")); - assertThat(embeddingsRequestCaptor.getValue().options()).isEqualTo(Map.of("main_gpu", 666)); - assertThat(embeddingsRequestCaptor.getValue().model()).isEqualTo("RUNTIME_MODEL"); + assertThat(this.embeddingsRequestCaptor.getValue().keepAlive()).isEqualTo(Duration.ofMinutes(10)); + assertThat(this.embeddingsRequestCaptor.getValue().truncate()).isFalse(); + assertThat(this.embeddingsRequestCaptor.getValue().input()).isEqualTo(List.of("Input4", "Input5", "Input6")); + assertThat(this.embeddingsRequestCaptor.getValue().options()).isEqualTo(Map.of("main_gpu", 666)); + assertThat(this.embeddingsRequestCaptor.getValue().model()).isEqualTo("RUNTIME_MODEL"); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java index dfa8c9f22..309ebc2eb 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; +import java.util.List; + import org.junit.jupiter.api.Test; + import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -38,7 +40,7 @@ public class OllamaEmbeddingRequestTests { @Test public void ollamaEmbeddingRequestDefaultOptions() { - var request = embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), null); + var request = this.embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), null); assertThat(request.model()).isEqualTo("DEFAULT_MODEL"); assertThat(request.options().get("num_gpu")).isEqualTo(1); @@ -56,7 +58,7 @@ public class OllamaEmbeddingRequestTests { .withUseMMap(true)// .withNumGPU(2); - var request = embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), promptOptions); + var request = this.embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), promptOptions); assertThat(request.model()).isEqualTo("PROMPT_MODEL"); assertThat(request.options().get("num_gpu")).isEqualTo(2); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java index 8a13c29b5..1e2bf625f 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama; import org.testcontainers.utility.DockerImageName; diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java index 4b7e3f49d..3b030e8c6 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.aot; +import java.util.Set; + import org.junit.jupiter.api.Test; + import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; -import java.util.Set; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java index f9ed57882..cbf53fb3e 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.api; +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; +import org.testcontainers.junit.jupiter.Testcontainers; +import reactor.core.publisher.Flux; + import org.springframework.ai.ollama.BaseOllamaIT; import org.springframework.ai.ollama.api.OllamaApi.ChatRequest; import org.springframework.ai.ollama.api.OllamaApi.ChatResponse; @@ -27,12 +35,6 @@ import org.springframework.ai.ollama.api.OllamaApi.GenerateRequest; import org.springframework.ai.ollama.api.OllamaApi.GenerateResponse; import org.springframework.ai.ollama.api.OllamaApi.Message; import org.springframework.ai.ollama.api.OllamaApi.Message.Role; -import org.testcontainers.junit.jupiter.Testcontainers; -import reactor.core.publisher.Flux; - -import java.io.IOException; -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -137,4 +139,4 @@ public class OllamaApiIT extends BaseOllamaIT { assertThat(response.totalDuration()).isGreaterThan(1); } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java index bc4e878ab..f56525701 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiModelsIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.ollama.api; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.ollama.api; import java.io.IOException; import java.time.Duration; @@ -23,9 +22,12 @@ import java.time.Duration; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ollama.BaseOllamaIT; import org.springframework.http.HttpStatus; -import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; /** * Integration tests for the Ollama APIs to manage models. @@ -98,4 +100,4 @@ public class OllamaApiModelsIT extends BaseOllamaIT { assertThat(listModelResponse.models().stream().anyMatch(model -> model.name().contains(MODEL))).isTrue(); } -} \ No newline at end of file +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java index faffdf242..205df66fe 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.api; -import org.junit.jupiter.api.Test; - import java.util.List; +import org.junit.jupiter.api.Test; + import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/MockWeatherService.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/MockWeatherService.java index 64cb56fd6..c732a8e5e 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/MockWeatherService.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,29 +13,37 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.ollama.api.tool; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * @author Christian Tzolov */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -63,28 +71,23 @@ public class MockWeatherService implements Function + + 4.0.0 diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/ImageResponseMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/ImageResponseMetadata.java index 3ec1ad510..5e1e1619a 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/ImageResponseMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/ImageResponseMetadata.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.openai; public interface ImageResponseMetadata { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java index 13057cb1a..73813b383 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,6 +19,8 @@ package org.springframework.ai.openai; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat; @@ -33,7 +35,6 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import reactor.core.publisher.Flux; /** * OpenAI audio speech client implementation for backed by {@link OpenAiAudioApi}. @@ -46,6 +47,12 @@ import reactor.core.publisher.Flux; */ public class OpenAiAudioSpeechModel implements SpeechModel, StreamingSpeechModel { + /** + * The speed of the default voice synthesis. + * @see OpenAiAudioSpeechOptions + */ + private static final Float SPEED = 1.0f; + private final Logger logger = LoggerFactory.getLogger(getClass()); /** @@ -53,12 +60,6 @@ public class OpenAiAudioSpeechModel implements SpeechModel, StreamingSpeechModel */ private final OpenAiAudioSpeechOptions defaultOptions; - /** - * The speed of the default voice synthesis. - * @see OpenAiAudioSpeechOptions - */ - private static final Float SPEED = 1.0f; - /** * The retry template used to retry the OpenAI Audio API calls. */ @@ -131,7 +132,7 @@ public class OpenAiAudioSpeechModel implements SpeechModel, StreamingSpeechModel var speech = speechEntity.getBody(); if (speech == null) { - logger.warn("No speech response returned for speechRequest: {}", speechRequest); + this.logger.warn("No speech response returned for speechRequest: {}", speechRequest); return new SpeechResponse(new Speech(new byte[0])); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechOptions.java index 8d6ca7c9d..47cfc153c 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,6 +18,7 @@ package org.springframework.ai.openai; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.model.ModelOptions; import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat; import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.Voice; @@ -70,137 +71,150 @@ public class OpenAiAudioSpeechOptions implements ModelOptions { return new Builder(); } - public static class Builder { - - private final OpenAiAudioSpeechOptions options = new OpenAiAudioSpeechOptions(); - - public Builder withModel(String model) { - options.model = model; - return this; - } - - public Builder withInput(String input) { - options.input = input; - return this; - } - - public Builder withVoice(Voice voice) { - options.voice = voice; - return this; - } - - public Builder withResponseFormat(AudioResponseFormat responseFormat) { - options.responseFormat = responseFormat; - return this; - } - - public Builder withSpeed(Float speed) { - options.speed = speed; - return this; - } - - public OpenAiAudioSpeechOptions build() { - return options; - } - - } - public String getModel() { - return model; - } - - public String getInput() { - return input; - } - - public Voice getVoice() { - return voice; - } - - public AudioResponseFormat getResponseFormat() { - return responseFormat; - } - - public Float getSpeed() { - return speed; - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((input == null) ? 0 : input.hashCode()); - result = prime * result + ((voice == null) ? 0 : voice.hashCode()); - result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); - result = prime * result + ((speed == null) ? 0 : speed.hashCode()); - return result; + return this.model; } public void setModel(String model) { this.model = model; } + public String getInput() { + return this.input; + } + public void setInput(String input) { this.input = input; } + public Voice getVoice() { + return this.voice; + } + public void setVoice(Voice voice) { this.voice = voice; } + public AudioResponseFormat getResponseFormat() { + return this.responseFormat; + } + public void setResponseFormat(AudioResponseFormat responseFormat) { this.responseFormat = responseFormat; } + public Float getSpeed() { + return this.speed; + } + public void setSpeed(Float speed) { this.speed = speed; } + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.input == null) ? 0 : this.input.hashCode()); + result = prime * result + ((this.voice == null) ? 0 : this.voice.hashCode()); + result = prime * result + ((this.responseFormat == null) ? 0 : this.responseFormat.hashCode()); + result = prime * result + ((this.speed == null) ? 0 : this.speed.hashCode()); + return result; + } + @Override public boolean equals(Object obj) { - if (this == obj) + if (this == obj) { return true; - if (obj == null) + } + if (obj == null) { return false; - if (getClass() != obj.getClass()) + } + if (getClass() != obj.getClass()) { return false; + } OpenAiAudioSpeechOptions other = (OpenAiAudioSpeechOptions) obj; - if (model == null) { - if (other.model != null) + if (this.model == null) { + if (other.model != null) { return false; + } } - else if (!model.equals(other.model)) + else if (!this.model.equals(other.model)) { return false; - if (input == null) { - if (other.input != null) + } + if (this.input == null) { + if (other.input != null) { return false; + } } - else if (!input.equals(other.input)) + else if (!this.input.equals(other.input)) { return false; - if (voice == null) { - if (other.voice != null) + } + if (this.voice == null) { + if (other.voice != null) { return false; + } } - else if (!voice.equals(other.voice)) + else if (!this.voice.equals(other.voice)) { return false; - if (responseFormat == null) { - if (other.responseFormat != null) + } + if (this.responseFormat == null) { + if (other.responseFormat != null) { return false; + } } - else if (!responseFormat.equals(other.responseFormat)) + else if (!this.responseFormat.equals(other.responseFormat)) { return false; - if (speed == null) { + } + if (this.speed == null) { return other.speed == null; } - else - return speed.equals(other.speed); + else { + return this.speed.equals(other.speed); + } } @Override public String toString() { - return "OpenAiAudioSpeechOptions{" + "model='" + model + '\'' + ", input='" + input + '\'' + ", voice='" + voice - + '\'' + ", responseFormat='" + responseFormat + '\'' + ", speed=" + speed + '}'; + return "OpenAiAudioSpeechOptions{" + "model='" + this.model + '\'' + ", input='" + this.input + '\'' + + ", voice='" + this.voice + '\'' + ", responseFormat='" + this.responseFormat + '\'' + ", speed=" + + this.speed + '}'; + } + + public static class Builder { + + private final OpenAiAudioSpeechOptions options = new OpenAiAudioSpeechOptions(); + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withInput(String input) { + this.options.input = input; + return this; + } + + public Builder withVoice(Voice voice) { + this.options.voice = voice; + return this; + } + + public Builder withResponseFormat(AudioResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withSpeed(Float speed) { + this.options.speed = speed; + return this; + } + + public OpenAiAudioSpeechOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionModel.java index fbf51bb78..8aa43728b 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,34 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -/* -* Copyright 2024-2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ package org.springframework.ai.openai; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.audio.transcription.AudioTranscription; +import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; +import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.model.Model; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.api.OpenAiAudioApi.StructuredResponse; -import org.springframework.ai.audio.transcription.AudioTranscription; -import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; -import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; import org.springframework.ai.openai.metadata.audio.OpenAiAudioTranscriptionResponseMetadata; import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; import org.springframework.ai.retry.RetryUtils; @@ -133,7 +118,7 @@ public class OpenAiAudioTranscriptionModel implements Model getLogitBias() { + return this.logitBias; + } + + public void setLogitBias(Map logitBias) { + this.logitBias = logitBias; + } + + public Boolean getLogprobs() { + return this.logprobs; + } + + public void setLogprobs(Boolean logprobs) { + this.logprobs = logprobs; + } + + public Integer getTopLogprobs() { + return this.topLogprobs; + } + + public void setTopLogprobs(Integer topLogprobs) { + this.topLogprobs = topLogprobs; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Integer getMaxCompletionTokens() { + return this.maxCompletionTokens; + } + + public void setMaxCompletionTokens(Integer maxCompletionTokens) { + this.maxCompletionTokens = maxCompletionTokens; + } + + public Integer getN() { + return this.n; + } + + public void setN(Integer n) { + this.n = n; + } + + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public ResponseFormat getResponseFormat() { + return this.responseFormat; + } + + public void setResponseFormat(ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + public StreamOptions getStreamOptions() { + return this.streamOptions; + } + + public void setStreamOptions(StreamOptions streamOptions) { + this.streamOptions = streamOptions; + } + + public Integer getSeed() { + return this.seed; + } + + public void setSeed(Integer seed) { + this.seed = seed; + } + + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + + public List getStop() { + return this.stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + @Override + public Double getTemperature() { + return this.temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + @Override + public Double getTopP() { + return this.topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + public List getTools() { + return this.tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + public String getToolChoice() { + return this.toolChoice; + } + + public void setToolChoice(String toolChoice) { + this.toolChoice = toolChoice; + } + + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + + public String getUser() { + return this.user; + } + + public void setUser(String user) { + this.user = user; + } + + public Boolean getParallelToolCalls() { + return this.parallelToolCalls; + } + + public void setParallelToolCalls(Boolean parallelToolCalls) { + this.parallelToolCalls = parallelToolCalls; + } + + @Override + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + @Override + public void setFunctionCallbacks(List functionCallbacks) { + this.functionCallbacks = functionCallbacks; + } + + @Override + public Set getFunctions() { + return this.functions; + } + + public void setFunctions(Set functionNames) { + this.functions = functionNames; + } + + public Map getHttpHeaders() { + return this.httpHeaders; + } + + public void setHttpHeaders(Map httpHeaders) { + this.httpHeaders = httpHeaders; + } + + @Override + @JsonIgnore + public Integer getTopK() { + return null; + } + + @Override + public Map getToolContext() { + return this.toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + + @Override + public OpenAiChatOptions copy() { + return OpenAiChatOptions.fromOptions(this); + } + + @Override + public int hashCode() { + return Objects.hash(this.model, this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs, + this.maxTokens, this.maxCompletionTokens, this.n, this.presencePenalty, this.responseFormat, + this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice, + this.user, this.parallelToolCalls, this.functionCallbacks, this.functions, this.httpHeaders, + this.proxyToolCalls, this.toolContext); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + OpenAiChatOptions other = (OpenAiChatOptions) o; + return Objects.equals(this.model, other.model) && Objects.equals(this.frequencyPenalty, other.frequencyPenalty) + && Objects.equals(this.logitBias, other.logitBias) && Objects.equals(this.logprobs, other.logprobs) + && Objects.equals(this.topLogprobs, other.topLogprobs) + && Objects.equals(this.maxTokens, other.maxTokens) + && Objects.equals(this.maxCompletionTokens, other.maxCompletionTokens) + && Objects.equals(this.n, other.n) && Objects.equals(this.presencePenalty, other.presencePenalty) + && Objects.equals(this.responseFormat, other.responseFormat) + && Objects.equals(this.streamOptions, other.streamOptions) && Objects.equals(this.seed, other.seed) + && Objects.equals(this.stop, other.stop) && Objects.equals(this.temperature, other.temperature) + && Objects.equals(this.topP, other.topP) && Objects.equals(this.tools, other.tools) + && Objects.equals(this.toolChoice, other.toolChoice) && Objects.equals(this.user, other.user) + && Objects.equals(this.parallelToolCalls, other.parallelToolCalls) + && Objects.equals(this.functionCallbacks, other.functionCallbacks) + && Objects.equals(this.functions, other.functions) + && Objects.equals(this.httpHeaders, other.httpHeaders) + && Objects.equals(this.toolContext, other.toolContext) + && Objects.equals(this.proxyToolCalls, other.proxyToolCalls); + } + + @Override + public String toString() { + return "OpenAiChatOptions: " + ModelOptionsUtils.toJsonString(this); + } + public static class Builder { protected OpenAiChatOptions options; @@ -357,307 +663,4 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions { } - public Boolean getStreamUsage() { - return this.streamOptions != null; - } - - public void setStreamUsage(Boolean enableStreamUsage) { - this.streamOptions = (enableStreamUsage) ? StreamOptions.INCLUDE_USAGE : null; - } - - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - @Override - public Double getFrequencyPenalty() { - return this.frequencyPenalty; - } - - public void setFrequencyPenalty(Double frequencyPenalty) { - this.frequencyPenalty = frequencyPenalty; - } - - public Map getLogitBias() { - return this.logitBias; - } - - public void setLogitBias(Map logitBias) { - this.logitBias = logitBias; - } - - public Boolean getLogprobs() { - return this.logprobs; - } - - public void setLogprobs(Boolean logprobs) { - this.logprobs = logprobs; - } - - public Integer getTopLogprobs() { - return this.topLogprobs; - } - - public void setTopLogprobs(Integer topLogprobs) { - this.topLogprobs = topLogprobs; - } - - @Override - public Integer getMaxTokens() { - return this.maxTokens; - } - - public void setMaxTokens(Integer maxTokens) { - this.maxTokens = maxTokens; - } - - public Integer getMaxCompletionTokens() { - return maxCompletionTokens; - } - - public void setMaxCompletionTokens(Integer maxCompletionTokens) { - this.maxCompletionTokens = maxCompletionTokens; - } - - public Integer getN() { - return this.n; - } - - public void setN(Integer n) { - this.n = n; - } - - @Override - public Double getPresencePenalty() { - return this.presencePenalty; - } - - public void setPresencePenalty(Double presencePenalty) { - this.presencePenalty = presencePenalty; - } - - public ResponseFormat getResponseFormat() { - return this.responseFormat; - } - - public void setResponseFormat(ResponseFormat responseFormat) { - this.responseFormat = responseFormat; - } - - public StreamOptions getStreamOptions() { - return streamOptions; - } - - public void setStreamOptions(StreamOptions streamOptions) { - this.streamOptions = streamOptions; - } - - public Integer getSeed() { - return this.seed; - } - - public void setSeed(Integer seed) { - this.seed = seed; - } - - @Override - @JsonIgnore - public List getStopSequences() { - return getStop(); - } - - @JsonIgnore - public void setStopSequences(List stopSequences) { - setStop(stopSequences); - } - - public List getStop() { - return this.stop; - } - - public void setStop(List stop) { - this.stop = stop; - } - - @Override - public Double getTemperature() { - return this.temperature; - } - - public void setTemperature(Double temperature) { - this.temperature = temperature; - } - - @Override - public Double getTopP() { - return this.topP; - } - - public void setTopP(Double topP) { - this.topP = topP; - } - - public List getTools() { - return this.tools; - } - - public void setTools(List tools) { - this.tools = tools; - } - - public String getToolChoice() { - return this.toolChoice; - } - - @Override - public Boolean getProxyToolCalls() { - return this.proxyToolCalls; - } - - public void setProxyToolCalls(Boolean proxyToolCalls) { - this.proxyToolCalls = proxyToolCalls; - } - - public void setToolChoice(String toolChoice) { - this.toolChoice = toolChoice; - } - - public String getUser() { - return this.user; - } - - public void setUser(String user) { - this.user = user; - } - - public Boolean getParallelToolCalls() { - return this.parallelToolCalls; - } - - public void setParallelToolCalls(Boolean parallelToolCalls) { - this.parallelToolCalls = parallelToolCalls; - } - - @Override - public List getFunctionCallbacks() { - return this.functionCallbacks; - } - - @Override - public void setFunctionCallbacks(List functionCallbacks) { - this.functionCallbacks = functionCallbacks; - } - - @Override - public Set getFunctions() { - return functions; - } - - public void setFunctions(Set functionNames) { - this.functions = functionNames; - } - - public Map getHttpHeaders() { - return this.httpHeaders; - } - - public void setHttpHeaders(Map httpHeaders) { - this.httpHeaders = httpHeaders; - } - - @Override - @JsonIgnore - public Integer getTopK() { - return null; - } - - @Override - public Map getToolContext() { - return this.toolContext; - } - - @Override - public void setToolContext(Map toolContext) { - this.toolContext = toolContext; - } - - @Override - public OpenAiChatOptions copy() { - return OpenAiChatOptions.fromOptions(this); - } - - public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { - return OpenAiChatOptions.builder() - .withModel(fromOptions.getModel()) - .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) - .withLogitBias(fromOptions.getLogitBias()) - .withLogprobs(fromOptions.getLogprobs()) - .withTopLogprobs(fromOptions.getTopLogprobs()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withMaxCompletionTokens(fromOptions.getMaxCompletionTokens()) - .withN(fromOptions.getN()) - .withPresencePenalty(fromOptions.getPresencePenalty()) - .withResponseFormat(fromOptions.getResponseFormat()) - .withStreamUsage(fromOptions.getStreamUsage()) - .withSeed(fromOptions.getSeed()) - .withStop(fromOptions.getStop()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withTools(fromOptions.getTools()) - .withToolChoice(fromOptions.getToolChoice()) - .withUser(fromOptions.getUser()) - .withParallelToolCalls(fromOptions.getParallelToolCalls()) - .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) - .withFunctions(fromOptions.getFunctions()) - .withHttpHeaders(fromOptions.getHttpHeaders()) - .withProxyToolCalls(fromOptions.getProxyToolCalls()) - .withToolContext(fromOptions.getToolContext()) - .build(); - } - - @Override - public int hashCode() { - return Objects.hash(this.model, this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs, - this.maxTokens, this.maxCompletionTokens, this.n, this.presencePenalty, this.responseFormat, - this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice, - this.user, this.parallelToolCalls, this.functionCallbacks, this.functions, this.httpHeaders, - this.proxyToolCalls, this.toolContext); - } - - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; - OpenAiChatOptions other = (OpenAiChatOptions) o; - return Objects.equals(this.model, other.model) && Objects.equals(this.frequencyPenalty, other.frequencyPenalty) - && Objects.equals(this.logitBias, other.logitBias) && Objects.equals(this.logprobs, other.logprobs) - && Objects.equals(this.topLogprobs, other.topLogprobs) - && Objects.equals(this.maxTokens, other.maxTokens) - && Objects.equals(this.maxCompletionTokens, other.maxCompletionTokens) - && Objects.equals(this.n, other.n) && Objects.equals(this.presencePenalty, other.presencePenalty) - && Objects.equals(this.responseFormat, other.responseFormat) - && Objects.equals(this.streamOptions, other.streamOptions) && Objects.equals(this.seed, other.seed) - && Objects.equals(this.stop, other.stop) && Objects.equals(this.temperature, other.temperature) - && Objects.equals(this.topP, other.topP) && Objects.equals(this.tools, other.tools) - && Objects.equals(this.toolChoice, other.toolChoice) && Objects.equals(this.user, other.user) - && Objects.equals(this.parallelToolCalls, other.parallelToolCalls) - && Objects.equals(this.functionCallbacks, other.functionCallbacks) - && Objects.equals(this.functions, other.functions) - && Objects.equals(this.httpHeaders, other.httpHeaders) - && Objects.equals(this.toolContext, other.toolContext) - && Objects.equals(this.proxyToolCalls, other.proxyToolCalls); - } - - @Override - public String toString() { - return "OpenAiChatOptions: " + ModelOptionsUtils.toJsonString(this); - } - } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java index 89f61224a..a0824cf56 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; +import java.util.List; + import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -27,9 +31,9 @@ import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation; -import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList; @@ -40,8 +44,6 @@ import org.springframework.lang.Nullable; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.List; - /** * Open AI Embedding Model implementation. * diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java index dbd0a2979..64f173eaf 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import com.fasterxml.jackson.annotation.JsonInclude; @@ -51,40 +52,6 @@ public class OpenAiEmbeddingOptions implements EmbeddingOptions { return new Builder(); } - public static class Builder { - - protected OpenAiEmbeddingOptions options; - - public Builder() { - this.options = new OpenAiEmbeddingOptions(); - } - - public Builder withModel(String model) { - this.options.setModel(model); - return this; - } - - public Builder withEncodingFormat(String encodingFormat) { - this.options.setEncodingFormat(encodingFormat); - return this; - } - - public Builder withDimensions(Integer dimensions) { - this.options.dimensions = dimensions; - return this; - } - - public Builder withUser(String user) { - this.options.setUser(user); - return this; - } - - public OpenAiEmbeddingOptions build() { - return this.options; - } - - } - @Override public String getModel() { return this.model; @@ -119,4 +86,38 @@ public class OpenAiEmbeddingOptions implements EmbeddingOptions { this.user = user; } + public static class Builder { + + protected OpenAiEmbeddingOptions options; + + public Builder() { + this.options = new OpenAiEmbeddingOptions(); + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withEncodingFormat(String encodingFormat) { + this.options.setEncodingFormat(encodingFormat); + return this; + } + + public Builder withDimensions(Integer dimensions) { + this.options.dimensions = dimensions; + return this; + } + + public Builder withUser(String user) { + this.options.setUser(user); + return this; + } + + public OpenAiEmbeddingOptions build() { + return this.options; + } + + } + } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java index 3fabb9f49..1da5a9b86 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; +import java.util.List; + import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.image.Image; import org.springframework.ai.image.ImageGeneration; import org.springframework.ai.image.ImageModel; @@ -26,8 +30,8 @@ import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; import org.springframework.ai.image.ImageResponseMetadata; import org.springframework.ai.image.observation.DefaultImageModelObservationConvention; -import org.springframework.ai.image.observation.ImageModelObservationConvention; import org.springframework.ai.image.observation.ImageModelObservationContext; +import org.springframework.ai.image.observation.ImageModelObservationConvention; import org.springframework.ai.image.observation.ImageModelObservationDocumentation; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.api.OpenAiImageApi; @@ -39,8 +43,6 @@ import org.springframework.lang.Nullable; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.List; - /** * OpenAiImageModel is a class that implements the ImageModel interface. It provides a * client for calling the OpenAI image generation API. diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java index 949614dae..1bab88096 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; +import java.util.Objects; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.image.ImageOptions; -import java.util.Objects; - /** * OpenAI Image API options. OpenAiImageOptions.java * @@ -99,60 +100,6 @@ public class OpenAiImageOptions implements ImageOptions { return new Builder(); } - public static class Builder { - - private final OpenAiImageOptions options; - - private Builder() { - this.options = new OpenAiImageOptions(); - } - - public Builder withN(Integer n) { - options.setN(n); - return this; - } - - public Builder withModel(String model) { - options.setModel(model); - return this; - } - - public Builder withQuality(String quality) { - options.setQuality(quality); - return this; - } - - public Builder withResponseFormat(String responseFormat) { - options.setResponseFormat(responseFormat); - return this; - } - - public Builder withWidth(Integer width) { - options.setWidth(width); - return this; - } - - public Builder withHeight(Integer height) { - options.setHeight(height); - return this; - } - - public Builder withStyle(String style) { - options.setStyle(style); - return this; - } - - public Builder withUser(String user) { - options.setUser(user); - return this; - } - - public OpenAiImageOptions build() { - return options; - } - - } - @Override public Integer getN() { return this.n; @@ -181,7 +128,7 @@ public class OpenAiImageOptions implements ImageOptions { @Override public String getResponseFormat() { - return responseFormat; + return this.responseFormat; } public void setResponseFormat(String responseFormat) { @@ -247,10 +194,6 @@ public class OpenAiImageOptions implements ImageOptions { this.user = user; } - public void setSize(String size) { - this.size = size; - } - public String getSize() { if (this.size != null) { return this.size; @@ -258,28 +201,91 @@ public class OpenAiImageOptions implements ImageOptions { return (this.width != null && this.height != null) ? this.width + "x" + this.height : null; } + public void setSize(String size) { + this.size = size; + } + @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof OpenAiImageOptions that)) + } + if (!(o instanceof OpenAiImageOptions that)) { return false; - return Objects.equals(n, that.n) && Objects.equals(model, that.model) && Objects.equals(width, that.width) - && Objects.equals(height, that.height) && Objects.equals(quality, that.quality) - && Objects.equals(responseFormat, that.responseFormat) && Objects.equals(size, that.size) - && Objects.equals(style, that.style) && Objects.equals(user, that.user); + } + return Objects.equals(this.n, that.n) && Objects.equals(this.model, that.model) + && Objects.equals(this.width, that.width) && Objects.equals(this.height, that.height) + && Objects.equals(this.quality, that.quality) + && Objects.equals(this.responseFormat, that.responseFormat) && Objects.equals(this.size, that.size) + && Objects.equals(this.style, that.style) && Objects.equals(this.user, that.user); } @Override public int hashCode() { - return Objects.hash(n, model, width, height, quality, responseFormat, size, style, user); + return Objects.hash(this.n, this.model, this.width, this.height, this.quality, this.responseFormat, this.size, + this.style, this.user); } @Override public String toString() { - return "OpenAiImageOptions{" + "n=" + n + ", model='" + model + '\'' + ", width=" + width + ", height=" + height - + ", quality='" + quality + '\'' + ", responseFormat='" + responseFormat + '\'' + ", size='" + size - + '\'' + ", style='" + style + '\'' + ", user='" + user + '\'' + '}'; + return "OpenAiImageOptions{" + "n=" + this.n + ", model='" + this.model + '\'' + ", width=" + this.width + + ", height=" + this.height + ", quality='" + this.quality + '\'' + ", responseFormat='" + + this.responseFormat + '\'' + ", size='" + this.size + '\'' + ", style='" + this.style + '\'' + + ", user='" + this.user + '\'' + '}'; + } + + public static class Builder { + + private final OpenAiImageOptions options; + + private Builder() { + this.options = new OpenAiImageOptions(); + } + + public Builder withN(Integer n) { + this.options.setN(n); + return this; + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withQuality(String quality) { + this.options.setQuality(quality); + return this; + } + + public Builder withResponseFormat(String responseFormat) { + this.options.setResponseFormat(responseFormat); + return this; + } + + public Builder withWidth(Integer width) { + this.options.setWidth(width); + return this; + } + + public Builder withHeight(Integer height) { + this.options.setHeight(height); + return this; + } + + public Builder withStyle(String style) { + this.options.setStyle(style); + return this; + } + + public Builder withUser(String user) { + this.options.setUser(user); + return this; + } + + public OpenAiImageOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java index dbf662af0..f6719710c 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,19 +16,28 @@ package org.springframework.ai.openai; +import java.util.ArrayList; +import java.util.List; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.moderation.*; +import org.springframework.ai.moderation.Categories; +import org.springframework.ai.moderation.CategoryScores; +import org.springframework.ai.moderation.Generation; +import org.springframework.ai.moderation.Moderation; +import org.springframework.ai.moderation.ModerationModel; +import org.springframework.ai.moderation.ModerationOptions; +import org.springframework.ai.moderation.ModerationPrompt; +import org.springframework.ai.moderation.ModerationResponse; +import org.springframework.ai.moderation.ModerationResult; import org.springframework.ai.openai.api.OpenAiModerationApi; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.ArrayList; -import java.util.List; - /** * OpenAiModerationModel is a class that implements the ModerationModel interface. It * provides a client for calling the OpenAI moderation generation API. @@ -40,12 +49,12 @@ public class OpenAiModerationModel implements ModerationModel { private final Logger logger = LoggerFactory.getLogger(getClass()); - private OpenAiModerationOptions defaultOptions; - private final OpenAiModerationApi openAiModerationApi; private final RetryTemplate retryTemplate; + private OpenAiModerationOptions defaultOptions; + public OpenAiModerationModel(OpenAiModerationApi openAiModerationApi) { this(openAiModerationApi, RetryUtils.DEFAULT_RETRY_TEMPLATE); } @@ -97,7 +106,7 @@ public class OpenAiModerationModel implements ModerationModel { OpenAiModerationApi.OpenAiModerationRequest openAiModerationRequest) { OpenAiModerationApi.OpenAiModerationResponse moderationApiResponse = moderationResponseEntity.getBody(); if (moderationApiResponse == null) { - logger.warn("No moderation response returned for request: {}", openAiModerationRequest); + this.logger.warn("No moderation response returned for request: {}", openAiModerationRequest); return new ModerationResponse(new Generation()); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationOptions.java index 9abacec51..496882313 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.moderation.ModerationOptions; import org.springframework.ai.openai.api.OpenAiModerationApi; @@ -39,25 +41,6 @@ public class OpenAiModerationOptions implements ModerationOptions { return new Builder(); } - public static class Builder { - - private final OpenAiModerationOptions options; - - private Builder() { - this.options = new OpenAiModerationOptions(); - } - - public Builder withModel(String model) { - options.setModel(model); - return this; - } - - public OpenAiModerationOptions build() { - return options; - } - - } - @Override public String getModel() { return this.model; @@ -67,4 +50,23 @@ public class OpenAiModerationOptions implements ModerationOptions { this.model = model; } + public static class Builder { + + private final OpenAiModerationOptions options; + + private Builder() { + this.options = new OpenAiModerationOptions(); + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public OpenAiModerationOptions build() { + return this.options; + } + + } + } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java index b355e1b20..3a4fe1fa5 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.aot; +import java.util.Set; + import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.api.OpenAiImageApi; @@ -25,8 +28,6 @@ import org.springframework.aot.hint.TypeReference; import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; -import java.util.Set; - import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 657f37a13..b4dd95d29 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api; import java.util.List; @@ -21,6 +22,12 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.function.Predicate; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.api.common.OpenAiApiConstants; @@ -39,13 +46,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - /** * Single class implementation of the * OpenAI Chat Completion @@ -74,6 +74,8 @@ public class OpenAiApi { private final WebClient webClient; + private OpenAiStreamFunctionCallingHelper chunkMerger = new OpenAiStreamFunctionCallingHelper(); + /** * Create a new chat completion api with base URL set to https://api.openai.com * @param apiKey OpenAI apiKey. @@ -173,6 +175,155 @@ public class OpenAiApi { .build();// @formatter:on } + public static String getTextContent(List content) { + return content.stream() + .filter(c -> "text".equals(c.type())) + .map(ChatCompletionMessage.MediaContent::text) + .reduce("", (a, b) -> a + b); + } + + /** + * Creates a model response for the given chat conversation. + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code + * and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + return chatCompletionEntity(chatRequest, new LinkedMultiValueMap<>()); + } + + /** + * Creates a model response for the given chat conversation. + * @param chatRequest The chat completion request. + * @param additionalHttpHeader Optional, additional HTTP headers to be added to the + * request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code + * and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest, + MultiValueMap additionalHttpHeader) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null."); + + return this.restClient.post() + .uri(this.completionsPath) + .headers(headers -> headers.addAll(additionalHttpHeader)) + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + return chatCompletionStream(chatRequest, new LinkedMultiValueMap<>()); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @param additionalHttpHeader Optional, additional HTTP headers to be added to the + * request. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest, + MultiValueMap additionalHttpHeader) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + AtomicBoolean isInsideTool = new AtomicBoolean(false); + + return this.webClient.post() + .uri(this.completionsPath) + .headers(headers -> headers.addAll(additionalHttpHeader)) + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + // cancels the flux stream after the "[DONE]" is received. + .takeUntil(SSE_DONE_PREDICATE) + // filters out the "[DONE]" message. + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + // Detect is the chunk is part of a streaming function call. + .map(chunk -> { + if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { + isInsideTool.set(true); + } + return chunk; + }) + // Group all chunks belonging to the same function call. + // Flux -> Flux> + .windowUntil(chunk -> { + if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + // Merging the window chunks into a single chunk. + // Reduce the inner Flux window into a single + // Mono, + // Flux> -> Flux> + .concatMapIterable(window -> { + Mono monoChunk = window.reduce( + new ChatCompletionChunk(null, null, null, null, null, null, null), + (previous, current) -> this.chunkMerger.merge(previous, current)); + return List.of(monoChunk); + }) + // Flux> -> Flux + .flatMap(mono -> mono); + } + + /** + * Creates an embedding vector representing the input text or token array. + * @param embeddingRequest The embedding request. + * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. + * @param Type of the entity in the data list. Can be a {@link String} or + * {@link List} of tokens (e.g. Integers). For embedding multiple inputs in a single + * request, You can pass a {@link List} of {@link String} or {@link List} of + * {@link List} of tokens. For example: + * + *
{@code List.of("text1", "text2", "text3") or List.of(List.of(1, 2, 3), List.of(3, 4, 5))} 
+ */ + public ResponseEntity> embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + + // Input text to embed, encoded as a string or array of tokens. To embed multiple + // inputs in a single + // request, pass an array of strings or array of token arrays. + Assert.notNull(embeddingRequest.input(), "The input can not be null."); + Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, + "The input must be either a String, or a List of Strings or List of List of integers."); + + // The input must not exceed the max input tokens for the model (8192 tokens for + // text-embedding-ada-002), cannot + // be an empty string, and any array must be 2048 dimensions or less. + if (embeddingRequest.input() instanceof List list) { + Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); + Assert.isTrue(list.size() <= 2048, "The list must be 2048 dimensions or less"); + Assert.isTrue( + list.get(0) instanceof String || list.get(0) instanceof Integer || list.get(0) instanceof List, + "The input must be either a String, or a List of Strings or list of list of integers."); + } + + return this.restClient.post() + .uri(this.embeddingsPath) + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + + }); + } + /** * OpenAI Chat Completion Models: * @@ -296,7 +447,7 @@ public class OpenAiApi { } public String getValue() { - return value; + return this.value; } @Override @@ -306,6 +457,79 @@ public class OpenAiApi { } + /** + * The reason the model stopped generating tokens. + */ + public enum ChatCompletionFinishReason { + + /** + * The model hit a natural stop point or a provided stop sequence. + */ + @JsonProperty("stop") + STOP, + /** + * The maximum number of tokens specified in the request was reached. + */ + @JsonProperty("length") + LENGTH, + /** + * The content was omitted due to a flag from our content filters. + */ + @JsonProperty("content_filter") + CONTENT_FILTER, + /** + * The model called a tool. + */ + @JsonProperty("tool_calls") + TOOL_CALLS, + /** + * (deprecated) The model called a function. + */ + @JsonProperty("function_call") + FUNCTION_CALL, + /** + * Only for compatibility with Mistral AI API. + */ + @JsonProperty("tool_call") + TOOL_CALL + + } + + /** + * OpenAI Embeddings Models: + *
Embeddings. + */ + public enum EmbeddingModel { + + /** + * Most capable embedding model for both english and non-english tasks. DIMENSION: + * 3072 + */ + TEXT_EMBEDDING_3_LARGE("text-embedding-3-large"), + + /** + * Increased performance over 2nd generation ada embedding model. DIMENSION: 1536 + */ + TEXT_EMBEDDING_3_SMALL("text-embedding-3-small"), + + /** + * Most capable 2nd generation embedding model, replacing 16 first generation + * models. DIMENSION: 1536 + */ + TEXT_EMBEDDING_ADA_002("text-embedding-ada-002"); + + public final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + /** * Represents a tool the model may call. Currently, only functions are supported as a * tool. @@ -523,9 +747,9 @@ public class OpenAiApi { * @return A new {@link ChatCompletionRequest} with the specified stream options. */ public ChatCompletionRequest withStreamOptions(StreamOptions streamOptions) { - return new ChatCompletionRequest(messages, model, frequencyPenalty, logitBias, logprobs, topLogprobs, maxTokens, maxCompletionTokens, n, presencePenalty, - responseFormat, seed, stop, stream, streamOptions, temperature, topP, - tools, toolChoice, parallelToolCalls, user); + return new ChatCompletionRequest(this.messages, this.model, this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.presencePenalty, + this.responseFormat, this.seed, this.stop, this.stream, streamOptions, this.temperature, this.topP, + this.tools, this.toolChoice, this.parallelToolCalls, this.user); } /** @@ -559,7 +783,20 @@ public class OpenAiApi { public record ResponseFormat( @JsonProperty("type") Type type, @JsonProperty("json_schema") JsonSchema jsonSchema ) { - + + public ResponseFormat(Type type) { + this(type, (JsonSchema) null); + } + + public ResponseFormat(Type type, String schema) { + this(type, "custom_schema", schema, true); + } + + @ConstructorBinding + public ResponseFormat(Type type, String name, String schema, Boolean strict) { + this(type, StringUtils.hasText(schema)? new JsonSchema(name, schema, strict): null); + } + public enum Type { /** * Generates a text response. (default) @@ -604,19 +841,6 @@ public class OpenAiApi { } } - public ResponseFormat(Type type) { - this(type, (JsonSchema) null); - } - - public ResponseFormat(Type type, String schema) { - this(type, "custom_schema", schema, true); - } - - @ConstructorBinding - public ResponseFormat(Type type, String name, String schema, Boolean strict) { - this(type, StringUtils.hasText(schema)? new JsonSchema(name, schema, strict): null); - } - } /** @@ -658,6 +882,16 @@ public class OpenAiApi { @JsonProperty("tool_calls") List toolCalls, @JsonProperty("refusal") String refusal) {// @formatter:on + /** + * Create a chat completion message with the given content and role. All other + * fields are null. + * @param content The contents of the message. + * @param role The role of the author of this message. + */ + public ChatCompletionMessage(Object content, Role role) { + this(content, role, null, null, null, null); + } + /** * Get message content as String. */ @@ -671,16 +905,6 @@ public class OpenAiApi { throw new IllegalStateException("The content is not a string!"); } - /** - * Create a chat completion message with the given content and role. All other - * fields are null. - * @param content The contents of the message. - * @param role The role of the author of this message. - */ - public ChatCompletionMessage(Object content, Role role) { - this(content, role, null, null, null, null); - } - /** * The role of the author of this message. */ @@ -725,19 +949,6 @@ public class OpenAiApi { @JsonProperty("text") String text, @JsonProperty("image_url") ImageUrl imageUrl) { // @formatter:on - /** - * @param url Either a URL of the image or the base64 encoded image data. The - * base64 encoded image data must have a special prefix in the following - * format: "data:{mimetype};base64,{base64-encoded-image-data}". - * @param detail Specifies the detail level of the image. - */ - @JsonInclude(Include.NON_NULL) - public record ImageUrl(@JsonProperty("url") String url, @JsonProperty("detail") String detail) { - - public ImageUrl(String url) { - this(url, null); - } - } /** * Shortcut constructor for a text content. @@ -754,6 +965,22 @@ public class OpenAiApi { public MediaContent(ImageUrl imageUrl) { this("image_url", null, imageUrl); } + + /** + * @param url Either a URL of the image or the base64 encoded image data. The + * base64 encoded image data must have a special prefix in the following + * format: "data:{mimetype};base64,{base64-encoded-image-data}". + * @param detail Specifies the detail level of the image. + */ + @JsonInclude(Include.NON_NULL) + public record ImageUrl(@JsonProperty("url") String url, @JsonProperty("detail") String detail) { + + public ImageUrl(String url) { + this(url, null); + } + + } + } /** @@ -777,6 +1004,7 @@ public class OpenAiApi { public ToolCall(String id, String type, ChatCompletionFunction function) { this(null, id, type, function); } + } /** @@ -791,50 +1019,6 @@ public class OpenAiApi { @JsonProperty("name") String name, @JsonProperty("arguments") String arguments) {// @formatter:on } - } - - public static String getTextContent(List content) { - return content.stream() - .filter(c -> "text".equals(c.type())) - .map(ChatCompletionMessage.MediaContent::text) - .reduce("", (a, b) -> a + b); - } - - /** - * The reason the model stopped generating tokens. - */ - public enum ChatCompletionFinishReason { - - /** - * The model hit a natural stop point or a provided stop sequence. - */ - @JsonProperty("stop") - STOP, - /** - * The maximum number of tokens specified in the request was reached. - */ - @JsonProperty("length") - LENGTH, - /** - * The content was omitted due to a flag from our content filters. - */ - @JsonProperty("content_filter") - CONTENT_FILTER, - /** - * The model called a tool. - */ - @JsonProperty("tool_calls") - TOOL_CALLS, - /** - * (deprecated) The model called a function. - */ - @JsonProperty("function_call") - FUNCTION_CALL, - /** - * Only for compatibility with Mistral AI API. - */ - @JsonProperty("tool_call") - TOOL_CALL } @@ -880,6 +1064,7 @@ public class OpenAiApi { @JsonProperty("logprobs") LogProbs logprobs) {// @formatter:on } + } /** @@ -928,9 +1113,13 @@ public class OpenAiApi { @JsonProperty("logprob") Float logprob, @JsonProperty("bytes") List probBytes) {// @formatter:on } + } + } + // Embeddings API + /** * Usage statistics for the completion request. * @@ -1018,144 +1207,6 @@ public class OpenAiApi { @JsonProperty("delta") ChatCompletionMessage delta, @JsonProperty("logprobs") LogProbs logprobs) {// @formatter:on } - } - - /** - * Creates a model response for the given chat conversation. - * @param chatRequest The chat completion request. - * @return Entity response with {@link ChatCompletion} as a body and HTTP status code - * and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - return chatCompletionEntity(chatRequest, new LinkedMultiValueMap<>()); - } - - /** - * Creates a model response for the given chat conversation. - * @param chatRequest The chat completion request. - * @param additionalHttpHeader Optional, additional HTTP headers to be added to the - * request. - * @return Entity response with {@link ChatCompletion} as a body and HTTP status code - * and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest, - MultiValueMap additionalHttpHeader) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); - Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null."); - - return this.restClient.post() - .uri(this.completionsPath) - .headers(headers -> headers.addAll(additionalHttpHeader)) - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletion.class); - } - - private OpenAiStreamFunctionCallingHelper chunkMerger = new OpenAiStreamFunctionCallingHelper(); - - /** - * Creates a streaming chat response for the given chat conversation. - * @param chatRequest The chat completion request. Must have the stream property set - * to true. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - return chatCompletionStream(chatRequest, new LinkedMultiValueMap<>()); - } - - /** - * Creates a streaming chat response for the given chat conversation. - * @param chatRequest The chat completion request. Must have the stream property set - * to true. - * @param additionalHttpHeader Optional, additional HTTP headers to be added to the - * request. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest, - MultiValueMap additionalHttpHeader) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); - - AtomicBoolean isInsideTool = new AtomicBoolean(false); - - return this.webClient.post() - .uri(this.completionsPath) - .headers(headers -> headers.addAll(additionalHttpHeader)) - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - // cancels the flux stream after the "[DONE]" is received. - .takeUntil(SSE_DONE_PREDICATE) - // filters out the "[DONE]" message. - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) - // Detect is the chunk is part of a streaming function call. - .map(chunk -> { - if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { - isInsideTool.set(true); - } - return chunk; - }) - // Group all chunks belonging to the same function call. - // Flux -> Flux> - .windowUntil(chunk -> { - if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - // Merging the window chunks into a single chunk. - // Reduce the inner Flux window into a single - // Mono, - // Flux> -> Flux> - .concatMapIterable(window -> { - Mono monoChunk = window.reduce( - new ChatCompletionChunk(null, null, null, null, null, null, null), - (previous, current) -> this.chunkMerger.merge(previous, current)); - return List.of(monoChunk); - }) - // Flux> -> Flux - .flatMap(mono -> mono); - } - - // Embeddings API - - /** - * OpenAI Embeddings Models: - * Embeddings. - */ - public enum EmbeddingModel { - - /** - * Most capable embedding model for both english and non-english tasks. DIMENSION: - * 3072 - */ - TEXT_EMBEDDING_3_LARGE("text-embedding-3-large"), - - /** - * Increased performance over 2nd generation ada embedding model. DIMENSION: 1536 - */ - TEXT_EMBEDDING_3_SMALL("text-embedding-3-small"), - - /** - * Most capable 2nd generation embedding model, replacing 16 first generation - * models. DIMENSION: 1536 - */ - TEXT_EMBEDDING_ADA_002("text-embedding-ada-002"); - - public final String value; - - EmbeddingModel(String value) { - this.value = value; - } - - public String getValue() { - return value; - } } @@ -1183,6 +1234,7 @@ public class OpenAiApi { public Embedding(Integer index, float[] embedding) { this(index, embedding, "embedding"); } + } /** @@ -1227,6 +1279,7 @@ public class OpenAiApi { public EmbeddingRequest(T input) { this(input, DEFAULT_EMBEDDING_MODEL); } + } /** @@ -1246,45 +1299,4 @@ public class OpenAiApi { @JsonProperty("usage") Usage usage) {// @formatter:on } - /** - * Creates an embedding vector representing the input text or token array. - * @param embeddingRequest The embedding request. - * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. - * @param Type of the entity in the data list. Can be a {@link String} or - * {@link List} of tokens (e.g. Integers). For embedding multiple inputs in a single - * request, You can pass a {@link List} of {@link String} or {@link List} of - * {@link List} of tokens. For example: - * - *
{@code List.of("text1", "text2", "text3") or List.of(List.of(1, 2, 3), List.of(3, 4, 5))} 
- */ - public ResponseEntity> embeddings(EmbeddingRequest embeddingRequest) { - - Assert.notNull(embeddingRequest, "The request body can not be null."); - - // Input text to embed, encoded as a string or array of tokens. To embed multiple - // inputs in a single - // request, pass an array of strings or array of token arrays. - Assert.notNull(embeddingRequest.input(), "The input can not be null."); - Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, - "The input must be either a String, or a List of Strings or List of List of integers."); - - // The input must not exceed the max input tokens for the model (8192 tokens for - // text-embedding-ada-002), cannot - // be an empty string, and any array must be 2048 dimensions or less. - if (embeddingRequest.input() instanceof List list) { - Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); - Assert.isTrue(list.size() <= 2048, "The list must be 2048 dimensions or less"); - Assert.isTrue( - list.get(0) instanceof String || list.get(0) instanceof Integer || list.get(0) instanceof List, - "The input must be either a String, or a List of Strings or list of list of integers."); - } - - return this.restClient.post() - .uri(this.embeddingsPath) - .body(embeddingRequest) - .retrieve() - .toEntity(new ParameterizedTypeReference<>() { - }); - } - } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java index 72161328f..a21702609 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api; import java.util.List; import java.util.Map; import java.util.function.Consumer; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.retry.RetryUtils; import org.springframework.core.io.ByteArrayResource; @@ -33,13 +40,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - /** * Turn audio into text or text into audio. Based on * OpenAI Audio @@ -124,6 +124,125 @@ public class OpenAiAudioApi { this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(authHeaders).build(); } + /** + * Request to generates audio from the input text. + * @param requestBody The request body. + * @return Response entity containing the audio binary. + */ + public ResponseEntity createSpeech(SpeechRequest requestBody) { + return this.restClient.post().uri("/v1/audio/speech").body(requestBody).retrieve().toEntity(byte[].class); + } + + /** + * Streams audio generated from the input text. + * + * This method sends a POST request to the OpenAI API to generate audio from the + * provided text. The audio is streamed back as a Flux of ResponseEntity objects, each + * containing a byte array of the audio data. + * @param requestBody The request body containing the details for the audio + * generation, such as the input text, model, voice, and response format. + * @return A Flux of ResponseEntity objects, each containing a byte array of the audio + * data. + */ + public Flux> stream(SpeechRequest requestBody) { + + return this.webClient.post() + .uri("/v1/audio/speech") + .body(Mono.just(requestBody), SpeechRequest.class) + .accept(MediaType.APPLICATION_OCTET_STREAM) + .exchangeToFlux(clientResponse -> { + HttpHeaders headers = clientResponse.headers().asHttpHeaders(); + return clientResponse.bodyToFlux(byte[].class) + .map(bytes -> ResponseEntity.ok().headers(headers).body(bytes)); + }); + } + + /** + * Transcribes audio into the input language. + * @param requestBody The request body. + * @return Response entity containing the transcribed text in either json or text + * format. + */ + public ResponseEntity createTranscription(TranscriptionRequest requestBody) { + return createTranscription(requestBody, requestBody.responseFormat().getResponseType()); + } + + /** + * Transcribes audio into the input language. The response type is specified by the + * responseType parameter. + * @param The response type. + * @param requestBody The request body. + * @param responseType The response type class. + * @return Response entity containing the transcribed text in the responseType format. + */ + public ResponseEntity createTranscription(TranscriptionRequest requestBody, Class responseType) { + + MultiValueMap multipartBody = new LinkedMultiValueMap<>(); + multipartBody.add("file", new ByteArrayResource(requestBody.file()) { + + @Override + public String getFilename() { + return "audio.webm"; + } + }); + multipartBody.add("model", requestBody.model()); + multipartBody.add("language", requestBody.language()); + multipartBody.add("prompt", requestBody.prompt()); + multipartBody.add("response_format", requestBody.responseFormat().getValue()); + multipartBody.add("temperature", requestBody.temperature()); + if (requestBody.granularityType() != null) { + Assert.isTrue(requestBody.responseFormat() == TranscriptResponseFormat.VERBOSE_JSON, + "response_format must be set to verbose_json to use timestamp granularities."); + multipartBody.add("timestamp_granularities[]", requestBody.granularityType().getValue()); + } + + return this.restClient.post() + .uri("/v1/audio/transcriptions") + .body(multipartBody) + .retrieve() + .toEntity(responseType); + } + + /** + * Translates audio into English. + * @param requestBody The request body. + * @return Response entity containing the transcribed text in either json or text + * format. + */ + public ResponseEntity createTranslation(TranslationRequest requestBody) { + return createTranslation(requestBody, requestBody.responseFormat().getResponseType()); + } + + /** + * Translates audio into English. The response type is specified by the responseType + * parameter. + * @param The response type. + * @param requestBody The request body. + * @param responseType The response type class. + * @return Response entity containing the transcribed text in the responseType format. + */ + public ResponseEntity createTranslation(TranslationRequest requestBody, Class responseType) { + + MultiValueMap multipartBody = new LinkedMultiValueMap<>(); + multipartBody.add("file", new ByteArrayResource(requestBody.file()) { + + @Override + public String getFilename() { + return "audio.webm"; + } + }); + multipartBody.add("model", requestBody.model()); + multipartBody.add("prompt", requestBody.prompt()); + multipartBody.add("response_format", requestBody.responseFormat().getValue()); + multipartBody.add("temperature", requestBody.temperature()); + + return this.restClient.post() + .uri("/v1/audio/translations") + .body(multipartBody) + .retrieve() + .toEntity(responseType); + } + /** * TTS is an AI model that converts text to natural sounding spoken text. We offer two * different model variates, tts-1 is optimized for real time text to speech use cases @@ -156,6 +275,69 @@ public class OpenAiAudioApi { } + /** + * Whisper is a + * general-purpose speech recognition model. It is trained on a large dataset of + * diverse audio and is also a multi-task model that can perform multilingual speech + * recognition as well as speech translation and language identification. The Whisper + * v2-large model is currently available through our API with the whisper-1 model + * name. + */ + public enum WhisperModel { + + // @formatter:off + @JsonProperty("whisper-1") WHISPER_1("whisper-1"); + // @formatter:on + + public final String value; + + WhisperModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + + /** + * The format of the transcript and translation outputs, in one of these options: + * json, text, srt, verbose_json, or vtt. Defaults to json. + */ + public enum TranscriptResponseFormat { + + // @formatter:off + @JsonProperty("json") JSON("json", StructuredResponse.class), + @JsonProperty("text") TEXT("text", String.class), + @JsonProperty("srt") SRT("srt", String.class), + @JsonProperty("verbose_json") VERBOSE_JSON("verbose_json", StructuredResponse.class), + @JsonProperty("vtt") VTT("vtt", String.class); + // @formatter:on + + public final String value; + + public final Class responseType; + + TranscriptResponseFormat(String value, Class responseType) { + this.value = value; + this.responseType = responseType; + } + + public boolean isJsonType() { + return this == JSON || this == VERBOSE_JSON; + } + + public String getValue() { + return this.value; + } + + public Class getResponseType() { + return this.responseType; + } + + } + /** * Request to generates audio from the input text. Reference: * Create @@ -181,12 +363,16 @@ public class OpenAiAudioApi { @JsonProperty("speed") Float speed) { // @formatter:on + public static Builder builder() { + return new Builder(); + } + /** * The voice to use for synthesis. */ public enum Voice { - // @formatter:off + // @formatter:off @JsonProperty("alloy") ALLOY("alloy"), @JsonProperty("echo") ECHO("echo"), @JsonProperty("fable") FABLE("fable"), @@ -232,10 +418,6 @@ public class OpenAiAudioApi { } - public static Builder builder() { - return new Builder(); - } - /** * Builder for the SpeechRequest. */ @@ -277,38 +459,13 @@ public class OpenAiAudioApi { } public SpeechRequest build() { - Assert.hasText(model, "model must not be empty"); - Assert.hasText(input, "input must not be empty"); + Assert.hasText(this.model, "model must not be empty"); + Assert.hasText(this.input, "input must not be empty"); return new SpeechRequest(this.model, this.input, this.voice, this.responseFormat, this.speed); } } - } - - /** - * Whisper is a - * general-purpose speech recognition model. It is trained on a large dataset of - * diverse audio and is also a multi-task model that can perform multilingual speech - * recognition as well as speech translation and language identification. The Whisper - * v2-large model is currently available through our API with the whisper-1 model - * name. - */ - public enum WhisperModel { - - // @formatter:off - @JsonProperty("whisper-1") WHISPER_1("whisper-1"); - // @formatter:on - - public final String value; - - WhisperModel(String value) { - this.value = value; - } - - public String getValue() { - return value; - } } @@ -347,6 +504,10 @@ public class OpenAiAudioApi { @JsonProperty("timestamp_granularities") GranularityType granularityType) { // @formatter:on + public static Builder builder() { + return new Builder(); + } + public enum GranularityType { // @formatter:off @@ -366,10 +527,6 @@ public class OpenAiAudioApi { } - public static Builder builder() { - return new Builder(); - } - public static class Builder { private byte[] file; @@ -431,42 +588,6 @@ public class OpenAiAudioApi { } } - } - - /** - * The format of the transcript and translation outputs, in one of these options: - * json, text, srt, verbose_json, or vtt. Defaults to json. - */ - public enum TranscriptResponseFormat { - - // @formatter:off - @JsonProperty("json") JSON("json", StructuredResponse.class), - @JsonProperty("text") TEXT("text", String.class), - @JsonProperty("srt") SRT("srt", String.class), - @JsonProperty("verbose_json") VERBOSE_JSON("verbose_json", StructuredResponse.class), - @JsonProperty("vtt") VTT("vtt", String.class); - // @formatter:on - - public final String value; - - public final Class responseType; - - public boolean isJsonType() { - return this == JSON || this == VERBOSE_JSON; - } - - TranscriptResponseFormat(String value, Class responseType) { - this.value = value; - this.responseType = responseType; - } - - public String getValue() { - return this.value; - } - - public Class getResponseType() { - return this.responseType; - } } @@ -537,15 +658,16 @@ public class OpenAiAudioApi { } public TranslationRequest build() { - Assert.notNull(file, "file must not be null"); - Assert.hasText(model, "model must not be empty"); - Assert.notNull(responseFormat, "response_format must not be null"); + Assert.notNull(this.file, "file must not be null"); + Assert.hasText(this.model, "model must not be empty"); + Assert.notNull(this.responseFormat, "response_format must not be null"); return new TranslationRequest(this.file, this.model, this.prompt, this.responseFormat, this.temperature); } } + } /** @@ -619,123 +741,7 @@ public class OpenAiAudioApi { @JsonProperty("no_speech_prob") Float noSpeechProb) { // @formatter:on } - } - /** - * Request to generates audio from the input text. - * @param requestBody The request body. - * @return Response entity containing the audio binary. - */ - public ResponseEntity createSpeech(SpeechRequest requestBody) { - return this.restClient.post().uri("/v1/audio/speech").body(requestBody).retrieve().toEntity(byte[].class); - } - - /** - * Streams audio generated from the input text. - * - * This method sends a POST request to the OpenAI API to generate audio from the - * provided text. The audio is streamed back as a Flux of ResponseEntity objects, each - * containing a byte array of the audio data. - * @param requestBody The request body containing the details for the audio - * generation, such as the input text, model, voice, and response format. - * @return A Flux of ResponseEntity objects, each containing a byte array of the audio - * data. - */ - public Flux> stream(SpeechRequest requestBody) { - - return webClient.post() - .uri("/v1/audio/speech") - .body(Mono.just(requestBody), SpeechRequest.class) - .accept(MediaType.APPLICATION_OCTET_STREAM) - .exchangeToFlux(clientResponse -> { - HttpHeaders headers = clientResponse.headers().asHttpHeaders(); - return clientResponse.bodyToFlux(byte[].class) - .map(bytes -> ResponseEntity.ok().headers(headers).body(bytes)); - }); - } - - /** - * Transcribes audio into the input language. - * @param requestBody The request body. - * @return Response entity containing the transcribed text in either json or text - * format. - */ - public ResponseEntity createTranscription(TranscriptionRequest requestBody) { - return createTranscription(requestBody, requestBody.responseFormat().getResponseType()); - } - - /** - * Transcribes audio into the input language. The response type is specified by the - * responseType parameter. - * @param The response type. - * @param requestBody The request body. - * @param responseType The response type class. - * @return Response entity containing the transcribed text in the responseType format. - */ - public ResponseEntity createTranscription(TranscriptionRequest requestBody, Class responseType) { - - MultiValueMap multipartBody = new LinkedMultiValueMap<>(); - multipartBody.add("file", new ByteArrayResource(requestBody.file()) { - @Override - public String getFilename() { - return "audio.webm"; - } - }); - multipartBody.add("model", requestBody.model()); - multipartBody.add("language", requestBody.language()); - multipartBody.add("prompt", requestBody.prompt()); - multipartBody.add("response_format", requestBody.responseFormat().getValue()); - multipartBody.add("temperature", requestBody.temperature()); - if (requestBody.granularityType() != null) { - Assert.isTrue(requestBody.responseFormat() == TranscriptResponseFormat.VERBOSE_JSON, - "response_format must be set to verbose_json to use timestamp granularities."); - multipartBody.add("timestamp_granularities[]", requestBody.granularityType().getValue()); - } - - return this.restClient.post() - .uri("/v1/audio/transcriptions") - .body(multipartBody) - .retrieve() - .toEntity(responseType); - } - - /** - * Translates audio into English. - * @param requestBody The request body. - * @return Response entity containing the transcribed text in either json or text - * format. - */ - public ResponseEntity createTranslation(TranslationRequest requestBody) { - return createTranslation(requestBody, requestBody.responseFormat().getResponseType()); - } - - /** - * Translates audio into English. The response type is specified by the responseType - * parameter. - * @param The response type. - * @param requestBody The request body. - * @param responseType The response type class. - * @return Response entity containing the transcribed text in the responseType format. - */ - public ResponseEntity createTranslation(TranslationRequest requestBody, Class responseType) { - - MultiValueMap multipartBody = new LinkedMultiValueMap<>(); - multipartBody.add("file", new ByteArrayResource(requestBody.file()) { - @Override - public String getFilename() { - return "audio.webm"; - } - }); - multipartBody.add("model", requestBody.model()); - multipartBody.add("prompt", requestBody.prompt()); - multipartBody.add("response_format", requestBody.responseFormat().getValue()); - multipartBody.add("temperature", requestBody.temperature()); - - return this.restClient.post() - .uri("/v1/audio/translations") - .body(multipartBody) - .retrieve() - .toEntity(responseType); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java index 698bbbae5..c53405407 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api; import java.util.List; import java.util.Map; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.MediaType; @@ -28,9 +32,6 @@ import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; - /** * OpenAI Image API. * @@ -95,6 +96,17 @@ public class OpenAiImageApi { // @formatter:on } + public ResponseEntity createImage(OpenAiImageRequest openAiImageRequest) { + Assert.notNull(openAiImageRequest, "Image request cannot be null."); + Assert.hasLength(openAiImageRequest.prompt(), "Prompt cannot be empty."); + + return this.restClient.post() + .uri("v1/images/generations") + .body(openAiImageRequest) + .retrieve() + .toEntity(OpenAiImageResponse.class); + } + /** * OpenAI Image API model. * DALL·E @@ -147,24 +159,12 @@ public class OpenAiImageApi { @JsonProperty("created") Long created, @JsonProperty("data") List data) { } - - @JsonInclude(JsonInclude.Include.NON_NULL) - public record Data( - @JsonProperty("url") String url, - @JsonProperty("b64_json") String b64Json, - @JsonProperty("revised_prompt") String revisedPrompt) { - } // @formatter:onn - public ResponseEntity createImage(OpenAiImageRequest openAiImageRequest) { - Assert.notNull(openAiImageRequest, "Image request cannot be null."); - Assert.hasLength(openAiImageRequest.prompt(), "Prompt cannot be empty."); + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Data(@JsonProperty("url") String url, @JsonProperty("b64_json") String b64Json, + @JsonProperty("revised_prompt") String revisedPrompt) { - return this.restClient.post() - .uri("v1/images/generations") - .body(openAiImageRequest) - .retrieve() - .toEntity(OpenAiImageResponse.class); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java index 092cc62bc..02e2b3ca1 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,8 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; + import org.springframework.ai.retry.RetryUtils; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; @@ -22,11 +28,6 @@ import org.springframework.util.Assert; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; - /** * OpenAI Moderation API. * @@ -36,10 +37,10 @@ import com.fasterxml.jackson.databind.ObjectMapper; */ public class OpenAiModerationApi { - private static final String DEFAULT_BASE_URL = "https://api.openai.com"; - public static final String DEFAULT_MODERATION_MODEL = "text-moderation-latest"; + private static final String DEFAULT_BASE_URL = "https://api.openai.com"; + private final RestClient restClient; private final ObjectMapper objectMapper; @@ -63,6 +64,17 @@ public class OpenAiModerationApi { }).defaultStatusHandler(responseErrorHandler).build(); } + public ResponseEntity createModeration(OpenAiModerationRequest openAiModerationRequest) { + Assert.notNull(openAiModerationRequest, "Moderation request cannot be null."); + Assert.hasLength(openAiModerationRequest.prompt(), "Prompt cannot be empty."); + + return this.restClient.post() + .uri("v1/moderations") + .body(openAiModerationRequest) + .retrieve() + .toEntity(OpenAiModerationResponse.class); + } + // @formatter:off @JsonInclude(JsonInclude.Include.NON_NULL) public record OpenAiModerationRequest ( @@ -82,6 +94,7 @@ public class OpenAiModerationApi { @JsonProperty("results") OpenAiModerationResult[] results) { } + @JsonInclude(JsonInclude.Include.NON_NULL) public record OpenAiModerationResult( @JsonProperty("flagged") boolean flagged, @@ -89,6 +102,7 @@ public class OpenAiModerationApi { @JsonProperty("category_scores") CategoryScores categoryScores) { } + @JsonInclude(JsonInclude.Include.NON_NULL) public record Categories( @JsonProperty("sexual") boolean sexual, @@ -119,25 +133,12 @@ public class OpenAiModerationApi { @JsonProperty("violence") double violence) { } - - - @JsonInclude(JsonInclude.Include.NON_NULL) - public record Data( - @JsonProperty("url") String url, - @JsonProperty("b64_json") String b64Json, - @JsonProperty("revised_prompt") String revisedPrompt) { - } // @formatter:onn - public ResponseEntity createModeration(OpenAiModerationRequest openAiModerationRequest) { - Assert.notNull(openAiModerationRequest, "Moderation request cannot be null."); - Assert.hasLength(openAiModerationRequest.prompt(), "Prompt cannot be empty."); + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Data(@JsonProperty("url") String url, @JsonProperty("b64_json") String b64Json, + @JsonProperty("revised_prompt") String revisedPrompt) { - return this.restClient.post() - .uri("v1/moderations") - .body(openAiModerationRequest) - .retrieve() - .toEntity(OpenAiModerationResponse.class); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java index 02bfd3108..9bdd2ea18 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api; import java.util.ArrayList; @@ -21,13 +22,13 @@ import java.util.List; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk.ChunkChoice; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionFinishReason; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; -import org.springframework.ai.openai.api.OpenAiApi.LogProbs; -import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk.ChunkChoice; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ChatCompletionFunction; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall; +import org.springframework.ai.openai.api.OpenAiApi.LogProbs; import org.springframework.ai.openai.api.OpenAiApi.Usage; import org.springframework.util.CollectionUtils; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiClientErrorException.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiClientErrorException.java index a53bc0bf6..7d5e96171 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiClientErrorException.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiClientErrorException.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api.common; /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiConstants.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiConstants.java index ebc454421..81051cf7b 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiConstants.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/OpenAiApiConstants.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.openai.api.common; import org.springframework.ai.observation.conventions.AiProvider; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/Speech.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/Speech.java index 592194021..93ae1cba3 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/Speech.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/Speech.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.audio.speech; +import java.util.Arrays; +import java.util.Objects; + import org.springframework.ai.model.ModelResult; import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechMetadata; import org.springframework.lang.Nullable; -import java.util.Arrays; -import java.util.Objects; - /** * The Speech class represents the result of speech synthesis from an AI model. It * implements the ModelResult interface with the output type of byte array. @@ -46,7 +47,7 @@ public class Speech implements ModelResult { @Override public OpenAiAudioSpeechMetadata getMetadata() { - return speechMetadata != null ? speechMetadata : OpenAiAudioSpeechMetadata.NULL; + return this.speechMetadata != null ? this.speechMetadata : OpenAiAudioSpeechMetadata.NULL; } public Speech withSpeechMetadata(@Nullable OpenAiAudioSpeechMetadata speechMetadata) { @@ -56,21 +57,23 @@ public class Speech implements ModelResult { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof Speech that)) + } + if (!(o instanceof Speech that)) { return false; - return Arrays.equals(audio, that.audio) && Objects.equals(speechMetadata, that.speechMetadata); + } + return Arrays.equals(this.audio, that.audio) && Objects.equals(this.speechMetadata, that.speechMetadata); } @Override public int hashCode() { - return Objects.hash(Arrays.hashCode(audio), speechMetadata); + return Objects.hash(Arrays.hashCode(this.audio), this.speechMetadata); } @Override public String toString() { - return "Speech{" + "text=" + audio + ", speechMetadata=" + speechMetadata + '}'; + return "Speech{" + "text=" + this.audio + ", speechMetadata=" + this.speechMetadata + '}'; } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechMessage.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechMessage.java index dcc96251b..dde419268 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechMessage.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechMessage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.audio.speech; import java.util.Objects; @@ -41,7 +42,7 @@ public class SpeechMessage { * @return the text of this speech message */ public String getText() { - return text; + return this.text; } /** @@ -54,16 +55,18 @@ public class SpeechMessage { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof SpeechMessage that)) + } + if (!(o instanceof SpeechMessage that)) { return false; - return Objects.equals(text, that.text); + } + return Objects.equals(this.text, that.text); } @Override public int hashCode() { - return Objects.hash(text); + return Objects.hash(this.text); } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechModel.java index 9d976fd75..f03370ce4 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechPrompt.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechPrompt.java index 8cb21684d..03fb07d6e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechPrompt.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechPrompt.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.audio.speech; +import java.util.Objects; + import org.springframework.ai.model.ModelOptions; import org.springframework.ai.model.ModelRequest; import org.springframework.ai.openai.OpenAiAudioSpeechOptions; -import java.util.Collections; -import java.util.List; -import java.util.Objects; - /** * The {@link SpeechPrompt} class represents a request to the OpenAI Text-to-Speech (TTS) * API. It contains a list of {@link SpeechMessage} objects, each representing a piece of @@ -33,10 +32,10 @@ import java.util.Objects; */ public class SpeechPrompt implements ModelRequest { - private OpenAiAudioSpeechOptions speechOptions; - private final SpeechMessage message; + private OpenAiAudioSpeechOptions speechOptions; + public SpeechPrompt(String instructions) { this(new SpeechMessage(instructions), OpenAiAudioSpeechOptions.builder().build()); } @@ -61,21 +60,23 @@ public class SpeechPrompt implements ModelRequest { @Override public ModelOptions getOptions() { - return speechOptions; + return this.speechOptions; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof SpeechPrompt that)) + } + if (!(o instanceof SpeechPrompt that)) { return false; - return Objects.equals(speechOptions, that.speechOptions) && Objects.equals(message, that.message); + } + return Objects.equals(this.speechOptions, that.speechOptions) && Objects.equals(this.message, that.message); } @Override public int hashCode() { - return Objects.hash(speechOptions, message); + return Objects.hash(this.speechOptions, this.message); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechResponse.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechResponse.java index 028bbf228..5b92fe770 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechResponse.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechResponse.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,13 +16,13 @@ package org.springframework.ai.openai.audio.speech; -import org.springframework.ai.model.ModelResponse; -import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata; - import java.util.Collections; import java.util.List; import java.util.Objects; +import org.springframework.ai.model.ModelResponse; +import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata; + /** * Creates a new instance of SpeechResponse with the given speech result. * @@ -60,32 +60,34 @@ public class SpeechResponse implements ModelResponse { @Override public Speech getResult() { - return speech; + return this.speech; } @Override public List getResults() { - return Collections.singletonList(speech); + return Collections.singletonList(this.speech); } @Override public OpenAiAudioSpeechResponseMetadata getMetadata() { - return speechResponseMetadata; + return this.speechResponseMetadata; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof SpeechResponse that)) + } + if (!(o instanceof SpeechResponse that)) { return false; - return Objects.equals(speech, that.speech) - && Objects.equals(speechResponseMetadata, that.speechResponseMetadata); + } + return Objects.equals(this.speech, that.speech) + && Objects.equals(this.speechResponseMetadata, that.speechResponseMetadata); } @Override public int hashCode() { - return Objects.hash(speech, speechResponseMetadata); + return Objects.hash(this.speech, this.speechResponseMetadata); } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/StreamingSpeechModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/StreamingSpeechModel.java index a8ae06b07..92dcfa347 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/StreamingSpeechModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/StreamingSpeechModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,9 +16,10 @@ package org.springframework.ai.openai.audio.speech; -import org.springframework.ai.model.StreamingModel; import reactor.core.publisher.Flux; +import org.springframework.ai.model.StreamingModel; + /** * The {@link StreamingSpeechModel} interface provides a way to interact with the OpenAI * Text-to-Speech (TTS) API using a streaming approach, allowing you to receive the diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java index bc4da401b..186095dc9 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.metadata; -import org.springframework.ai.image.ImageGenerationMetadata; - import java.util.Objects; +import org.springframework.ai.image.ImageGenerationMetadata; + public class OpenAiImageGenerationMetadata implements ImageGenerationMetadata { private String revisedPrompt; @@ -28,26 +29,28 @@ public class OpenAiImageGenerationMetadata implements ImageGenerationMetadata { } public String getRevisedPrompt() { - return revisedPrompt; + return this.revisedPrompt; } @Override public String toString() { - return "OpenAiImageGenerationMetadata{" + "revisedPrompt='" + revisedPrompt + '\'' + '}'; + return "OpenAiImageGenerationMetadata{" + "revisedPrompt='" + this.revisedPrompt + '\'' + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof OpenAiImageGenerationMetadata that)) + } + if (!(o instanceof OpenAiImageGenerationMetadata that)) { return false; - return Objects.equals(revisedPrompt, that.revisedPrompt); + } + return Objects.equals(this.revisedPrompt, that.revisedPrompt); } @Override public int hashCode() { - return Objects.hash(revisedPrompt); + return Objects.hash(this.revisedPrompt); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiModerationGenerationMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiModerationGenerationMetadata.java index b56226948..71d1dcee4 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiModerationGenerationMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiModerationGenerationMetadata.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,8 +18,6 @@ package org.springframework.ai.openai.metadata; import org.springframework.ai.moderation.ModerationGenerationMetadata; -import java.util.Objects; - public class OpenAiModerationGenerationMetadata implements ModerationGenerationMetadata { public OpenAiModerationGenerationMetadata() { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java index 7f5f214da..664de40a4 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.metadata; import java.time.Duration; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java index 46ec6ffb7..4e32bd153 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.metadata; import org.springframework.ai.chat.metadata.Usage; @@ -32,10 +33,6 @@ import org.springframework.util.Assert; */ public class OpenAiUsage implements Usage { - public static OpenAiUsage from(OpenAiApi.Usage usage) { - return new OpenAiUsage(usage); - } - private final OpenAiApi.Usage usage; protected OpenAiUsage(OpenAiApi.Usage usage) { @@ -43,6 +40,10 @@ public class OpenAiUsage implements Usage { this.usage = usage; } + public static OpenAiUsage from(OpenAiApi.Usage usage) { + return new OpenAiUsage(usage); + } + protected OpenAiApi.Usage getUsage() { return this.usage; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechMetadata.java index 85289d854..b6de47b4b 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -28,6 +28,7 @@ public interface OpenAiAudioSpeechMetadata extends ResultMetadata { */ static OpenAiAudioSpeechMetadata create() { return new OpenAiAudioSpeechMetadata() { + }; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java index efcb6ebca..e90c4097d 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,13 +19,10 @@ package org.springframework.ai.openai.metadata.audio; import org.springframework.ai.chat.metadata.EmptyRateLimit; import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.model.MutableResponseMetadata; -import org.springframework.ai.model.ResponseMetadata; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.lang.Nullable; import org.springframework.util.Assert; -import java.util.HashMap; - /** * Audio speech metadata implementation for {@literal OpenAI}. * @@ -34,10 +31,22 @@ import java.util.HashMap; */ public class OpenAiAudioSpeechResponseMetadata extends MutableResponseMetadata { + public static final OpenAiAudioSpeechResponseMetadata NULL = new OpenAiAudioSpeechResponseMetadata() { + + }; + protected static final String AI_METADATA_STRING = "{ @type: %1$s, requestsLimit: %2$s }"; - public static final OpenAiAudioSpeechResponseMetadata NULL = new OpenAiAudioSpeechResponseMetadata() { - }; + @Nullable + private RateLimit rateLimit; + + public OpenAiAudioSpeechResponseMetadata() { + this(null); + } + + public OpenAiAudioSpeechResponseMetadata(@Nullable RateLimit rateLimit) { + this.rateLimit = rateLimit; + } public static OpenAiAudioSpeechResponseMetadata from(OpenAiAudioApi.StructuredResponse result) { Assert.notNull(result, "OpenAI speech must not be null"); @@ -51,17 +60,6 @@ public class OpenAiAudioSpeechResponseMetadata extends MutableResponseMetadata { return speechResponseMetadata; } - @Nullable - private RateLimit rateLimit; - - public OpenAiAudioSpeechResponseMetadata() { - this(null); - } - - public OpenAiAudioSpeechResponseMetadata(@Nullable RateLimit rateLimit) { - this.rateLimit = rateLimit; - } - @Nullable public RateLimit getRateLimit() { RateLimit rateLimit = this.rateLimit; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioTranscriptionResponseMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioTranscriptionResponseMetadata.java index 7fc7d1755..106c9d726 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioTranscriptionResponseMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioTranscriptionResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.metadata.audio; import org.springframework.ai.audio.transcription.AudioTranscriptionResponseMetadata; import org.springframework.ai.chat.metadata.EmptyRateLimit; import org.springframework.ai.chat.metadata.RateLimit; -import org.springframework.ai.model.MutableResponseMetadata; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.metadata.OpenAiRateLimit; import org.springframework.lang.Nullable; @@ -34,20 +34,11 @@ import org.springframework.util.Assert; */ public class OpenAiAudioTranscriptionResponseMetadata extends AudioTranscriptionResponseMetadata { - protected static final String AI_METADATA_STRING = "{ @type: %1$s, rateLimit: %4$s }"; - public static final OpenAiAudioTranscriptionResponseMetadata NULL = new OpenAiAudioTranscriptionResponseMetadata() { + }; - public static OpenAiAudioTranscriptionResponseMetadata from(OpenAiAudioApi.StructuredResponse result) { - Assert.notNull(result, "OpenAI Transcription must not be null"); - return new OpenAiAudioTranscriptionResponseMetadata(); - } - - public static OpenAiAudioTranscriptionResponseMetadata from(String result) { - Assert.notNull(result, "OpenAI Transcription must not be null"); - return new OpenAiAudioTranscriptionResponseMetadata(); - } + protected static final String AI_METADATA_STRING = "{ @type: %1$s, rateLimit: %4$s }"; @Nullable private RateLimit rateLimit; @@ -60,6 +51,16 @@ public class OpenAiAudioTranscriptionResponseMetadata extends AudioTranscription this.rateLimit = rateLimit; } + public static OpenAiAudioTranscriptionResponseMetadata from(OpenAiAudioApi.StructuredResponse result) { + Assert.notNull(result, "OpenAI Transcription must not be null"); + return new OpenAiAudioTranscriptionResponseMetadata(); + } + + public static OpenAiAudioTranscriptionResponseMetadata from(String result) { + Assert.notNull(result, "OpenAI Transcription must not be null"); + return new OpenAiAudioTranscriptionResponseMetadata(); + } + @Nullable public RateLimit getRateLimit() { RateLimit rateLimit = this.rateLimit; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiApiResponseHeaders.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiApiResponseHeaders.java index 47d3d5f2d..5c6107c8e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiApiResponseHeaders.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiApiResponseHeaders.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.metadata.support; /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java index f472ad060..1d46556cc 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.metadata.support; import java.time.Duration; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java index f91edff96..74ca86bef 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import java.util.List; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java index 7083d7980..f11613fa1 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiImageOptionsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import org.junit.jupiter.api.Test; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index 24be2e910..d9e6b6ca5 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatModel; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.api.OpenAiImageApi; -import org.springframework.ai.openai.api.OpenAiApi.ChatModel; import org.springframework.ai.openai.api.OpenAiModerationApi; import org.springframework.boot.SpringBootConfiguration; import org.springframework.context.annotation.Bean; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/TranscriptionRequestTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/TranscriptionRequestTests.java index 96a95ba4e..7c7c467e2 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/TranscriptionRequestTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/TranscriptionRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import org.junit.jupiter.api.Test; +import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptResponseFormat; import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionRequest.GranularityType; -import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; import org.springframework.core.io.DefaultResourceLoader; import static org.assertj.core.api.Assertions.assertThat; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java index abe5b954d..2701a2fe8 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.acme; import java.util.List; @@ -24,16 +25,16 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.document.Document; -import org.springframework.ai.openai.OpenAiChatModel; -import org.springframework.ai.openai.OpenAiTestConfiguration; -import org.springframework.ai.openai.OpenAiEmbeddingModel; -import org.springframework.ai.openai.testutils.AbstractIT; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.document.Document; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiEmbeddingModel; +import org.springframework.ai.openai.OpenAiTestConfiguration; +import org.springframework.ai.openai.testutils.AbstractIT; import org.springframework.ai.reader.JsonReader; import org.springframework.ai.transformer.splitter.TokenTextSplitter; import org.springframework.ai.vectorstore.SimpleVectorStore; @@ -65,23 +66,23 @@ public class AcmeIT extends AbstractIT { @Test void beanTest() { - assertThat(bikesResource).isNotNull(); - assertThat(embeddingModel).isNotNull(); - assertThat(chatModel).isNotNull(); + assertThat(this.bikesResource).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); + assertThat(this.chatModel).isNotNull(); } // @Test void acmeChain() { // Step 1 - load documents - JsonReader jsonReader = new JsonReader(bikesResource, "name", "price", "shortDescription", "description"); + JsonReader jsonReader = new JsonReader(this.bikesResource, "name", "price", "shortDescription", "description"); var textSplitter = new TokenTextSplitter(); // Step 2 - Create embeddings and save to vector store logger.info("Creating Embeddings..."); - VectorStore vectorStore = new SimpleVectorStore(embeddingModel); + VectorStore vectorStore = new SimpleVectorStore(this.embeddingModel); vectorStore.accept(textSplitter.apply(jsonReader.get())); @@ -108,7 +109,7 @@ public class AcmeIT extends AbstractIT { logger.info("Asking AI generative to reply to question."); Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); logger.info("AI responded."); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); evaluateQuestionAndAnswer(userQuery, response, true); } @@ -119,7 +120,7 @@ public class AcmeIT extends AbstractIT { .map(entry -> entry.getContent()) .collect(Collectors.joining(System.lineSeparator())); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemBikePrompt); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemBikePrompt); Message systemMessage = systemPromptTemplate.createMessage(Map.of("documents", documents)); return systemMessage; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/aot/OpenAiRuntimeHintsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/aot/OpenAiRuntimeHintsTests.java index e4399ec8b..4b54d809c 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/aot/OpenAiRuntimeHintsTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/aot/OpenAiRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.aot; +import java.util.Set; + import org.junit.jupiter.api.Test; + import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; -import java.util.Set; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java index 17a002276..a07d400e1 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api; import java.util.List; @@ -43,7 +44,7 @@ public class OpenAiApiIT { @Test void chatCompletionEntity() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - ResponseEntity response = openAiApi.chatCompletionEntity( + ResponseEntity response = this.openAiApi.chatCompletionEntity( new ChatCompletionRequest(List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false)); assertThat(response).isNotNull(); @@ -53,7 +54,7 @@ public class OpenAiApiIT { @Test void chatCompletionStream() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - Flux response = openAiApi.chatCompletionStream( + Flux response = this.openAiApi.chatCompletionStream( new ChatCompletionRequest(List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true)); assertThat(response).isNotNull(); @@ -62,7 +63,7 @@ public class OpenAiApiIT { @Test void embeddings() { - ResponseEntity> response = openAiApi + ResponseEntity> response = this.openAiApi .embeddings(new OpenAiApi.EmbeddingRequest("Hello world")); assertThat(response).isNotNull(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/MockWeatherService.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/MockWeatherService.java index db41af1f0..88e5df176 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/MockWeatherService.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.api.tool; import java.util.function.Function; @@ -28,16 +29,21 @@ import com.fasterxml.jackson.annotation.JsonPropertyDescription; */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -65,28 +71,25 @@ public class MockWeatherService implements Function T fromJson(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + @SuppressWarnings("null") @Test public void toolFunctionCall() { @@ -95,7 +104,7 @@ public class OpenAiApiToolFunctionCallIT { List.of(functionTool), ToolChoiceBuilder.AUTO); // List.of(functionTool), ToolChoiceBuilder.FUNCTION("getCurrentWeather")); - ResponseEntity chatCompletion = completionApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity chatCompletion = this.completionApi.chatCompletionEntity(chatCompletionRequest); assertThat(chatCompletion.getBody()).isNotNull(); assertThat(chatCompletion.getBody().choices()).isNotEmpty(); @@ -116,7 +125,7 @@ public class OpenAiApiToolFunctionCallIT { MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(), MockWeatherService.Request.class); - MockWeatherService.Response weatherResponse = weatherService.apply(weatherRequest); + MockWeatherService.Response weatherResponse = this.weatherService.apply(weatherRequest); // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL, @@ -126,9 +135,10 @@ public class OpenAiApiToolFunctionCallIT { var functionResponseRequest = new ChatCompletionRequest(messages, "gpt-4o", 0.5); - ResponseEntity chatCompletion2 = completionApi.chatCompletionEntity(functionResponseRequest); + ResponseEntity chatCompletion2 = this.completionApi + .chatCompletionEntity(functionResponseRequest); - logger.info("Final response: " + chatCompletion2.getBody()); + this.logger.info("Final response: " + chatCompletion2.getBody()); assertThat(chatCompletion2.getBody().choices()).isNotEmpty(); @@ -144,13 +154,4 @@ public class OpenAiApiToolFunctionCallIT { } - private static T fromJson(String json, Class targetClass) { - try { - return new ObjectMapper().readValue(json, targetClass); - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiIT.java index f774711a6..a5c4123a9 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.audio.api; import java.io.File; @@ -23,10 +24,10 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest; -import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionRequest; -import org.springframework.ai.openai.api.OpenAiAudioApi.StructuredResponse; -import org.springframework.ai.openai.api.OpenAiAudioApi.TranslationRequest; import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.Voice; +import org.springframework.ai.openai.api.OpenAiAudioApi.StructuredResponse; +import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionRequest; +import org.springframework.ai.openai.api.OpenAiAudioApi.TranslationRequest; import org.springframework.ai.openai.api.OpenAiAudioApi.TtsModel; import org.springframework.ai.openai.api.OpenAiAudioApi.WhisperModel; import org.springframework.util.FileCopyUtils; @@ -45,7 +46,7 @@ public class OpenAiAudioApiIT { @Test void speechTranscriptionAndTranslation() throws IOException { - byte[] speech = audioApi + byte[] speech = this.audioApi .createSpeech(SpeechRequest.builder() .withModel(TtsModel.TTS_1_HD.getValue()) .withInput("Hello, my name is Chris and I love Spring A.I.") @@ -57,7 +58,7 @@ public class OpenAiAudioApiIT { FileCopyUtils.copy(speech, new File("target/speech.mp3")); - StructuredResponse translation = audioApi + StructuredResponse translation = this.audioApi .createTranslation( TranslationRequest.builder().withModel(WhisperModel.WHISPER_1.getValue()).withFile(speech).build(), StructuredResponse.class) @@ -65,7 +66,7 @@ public class OpenAiAudioApiIT { assertThat(translation.text().replaceAll(",", "")).isEqualTo("Hello my name is Chris and I love Spring AI."); - StructuredResponse transcriptionEnglish = audioApi.createTranscription( + StructuredResponse transcriptionEnglish = this.audioApi.createTranscription( TranscriptionRequest.builder().withModel(WhisperModel.WHISPER_1.getValue()).withFile(speech).build(), StructuredResponse.class) .getBody(); @@ -73,7 +74,7 @@ public class OpenAiAudioApiIT { assertThat(transcriptionEnglish.text().replaceAll(",", "")) .isEqualTo("Hello my name is Chris and I love Spring AI."); - StructuredResponse transcriptionDutch = audioApi + StructuredResponse transcriptionDutch = this.audioApi .createTranscription(TranscriptionRequest.builder().withFile(speech).withLanguage("nl").build(), StructuredResponse.class) .getBody(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelIT.java index 0ff96b259..780ab89e2 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,17 +16,18 @@ package org.springframework.ai.openai.audio.speech; +import java.util.List; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.openai.OpenAiAudioSpeechOptions; import org.springframework.ai.openai.OpenAiTestConfiguration; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata; import org.springframework.ai.openai.testutils.AbstractIT; import org.springframework.boot.test.context.SpringBootTest; -import reactor.core.publisher.Flux; - -import java.util.List; import static org.assertj.core.api.Assertions.assertThat; @@ -38,7 +39,7 @@ class OpenAiSpeechModelIT extends AbstractIT { @Test void shouldSuccessfullyStreamAudioBytesForEmptyMessage() { - Flux response = speechModel.stream("Today is a wonderful day to build something people love!"); + Flux response = this.speechModel.stream("Today is a wonderful day to build something people love!"); assertThat(response).isNotNull(); assertThat(response.collectList().block()).isNotNull(); System.out.println(response.collectList().block()); @@ -46,7 +47,7 @@ class OpenAiSpeechModelIT extends AbstractIT { @Test void shouldProduceAudioBytesDirectlyFromMessage() { - byte[] audioBytes = speechModel.call("Today is a wonderful day to build something people love!"); + byte[] audioBytes = this.speechModel.call("Today is a wonderful day to build something people love!"); assertThat(audioBytes).hasSizeGreaterThan(0); } @@ -61,7 +62,7 @@ class OpenAiSpeechModelIT extends AbstractIT { .build(); SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", speechOptions); - SpeechResponse response = speechModel.call(speechPrompt); + SpeechResponse response = this.speechModel.call(speechPrompt); byte[] audioBytes = response.getResult().getOutput(); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput()).isNotEmpty(); @@ -79,7 +80,7 @@ class OpenAiSpeechModelIT extends AbstractIT { .build(); SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", speechOptions); - SpeechResponse response = speechModel.call(speechPrompt); + SpeechResponse response = this.speechModel.call(speechPrompt); OpenAiAudioSpeechResponseMetadata metadata = response.getMetadata(); assertThat(metadata).isNotNull(); assertThat(metadata.getRateLimit()).isNotNull(); @@ -100,7 +101,7 @@ class OpenAiSpeechModelIT extends AbstractIT { SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", speechOptions); - Flux responseFlux = speechModel.stream(speechPrompt); + Flux responseFlux = this.speechModel.stream(speechPrompt); assertThat(responseFlux).isNotNull(); List responses = responseFlux.collectList().block(); assertThat(responses).isNotNull(); @@ -110,4 +111,4 @@ class OpenAiSpeechModelIT extends AbstractIT { }); } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelWithSpeechResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelWithSpeechResponseMetadataTests.java index 089c9c824..0371dd08b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelWithSpeechResponseMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelWithSpeechResponseMetadataTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,8 +16,11 @@ package org.springframework.ai.openai.audio.speech; +import java.time.Duration; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; + import org.springframework.ai.openai.OpenAiAudioSpeechModel; import org.springframework.ai.openai.OpenAiAudioSpeechOptions; import org.springframework.ai.openai.api.OpenAiAudioApi; @@ -34,12 +37,10 @@ import org.springframework.http.MediaType; import org.springframework.test.web.client.MockRestServiceServer; import org.springframework.web.client.RestClient; -import java.time.Duration; - import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; import static org.springframework.test.web.client.match.MockRestRequestMatchers.header; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; /** @@ -48,10 +49,10 @@ import static org.springframework.test.web.client.response.MockRestResponseCreat @RestClientTest(OpenAiSpeechModelWithSpeechResponseMetadataTests.Config.class) public class OpenAiSpeechModelWithSpeechResponseMetadataTests { - private static String TEST_API_KEY = "sk-1234567890"; - private static final Float SPEED = 1.0f; + private static String TEST_API_KEY = "sk-1234567890"; + @Autowired private OpenAiAudioSpeechModel openAiSpeechClient; @@ -60,7 +61,7 @@ public class OpenAiSpeechModelWithSpeechResponseMetadataTests { @AfterEach void resetMockServer() { - server.reset(); + this.server.reset(); } @Test @@ -77,7 +78,7 @@ public class OpenAiSpeechModelWithSpeechResponseMetadataTests { SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", speechOptions); - SpeechResponse response = openAiSpeechClient.call(speechPrompt); + SpeechResponse response = this.openAiSpeechClient.call(speechPrompt); byte[] audioBytes = response.getResult().getOutput(); assertThat(audioBytes).hasSizeGreaterThan(0); @@ -110,7 +111,7 @@ public class OpenAiSpeechModelWithSpeechResponseMetadataTests { httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); httpHeaders.setContentType(MediaType.APPLICATION_OCTET_STREAM); - server.expect(requestTo("/v1/audio/speech")) + this.server.expect(requestTo("/v1/audio/speech")) .andExpect(method(HttpMethod.POST)) .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) .andRespond(withSuccess("Audio bytes as string", MediaType.APPLICATION_OCTET_STREAM).headers(httpHeaders)); @@ -132,4 +133,4 @@ public class OpenAiSpeechModelWithSpeechResponseMetadataTests { } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelIT.java index f4a44d36a..bf252291c 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.audio.transcription; import org.junit.jupiter.api.Test; @@ -44,8 +45,9 @@ class OpenAiTranscriptionModelIT extends AbstractIT { .withResponseFormat(TranscriptResponseFormat.TEXT) .withTemperature(0f) .build(); - AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); - AudioTranscriptionResponse response = transcriptionModel.call(transcriptionRequest); + AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(this.audioFile, + transcriptionOptions); + AudioTranscriptionResponse response = this.transcriptionModel.call(transcriptionRequest); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue(); } @@ -60,8 +62,9 @@ class OpenAiTranscriptionModelIT extends AbstractIT { .withTemperature(0f) .withResponseFormat(responseFormat) .build(); - AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); - AudioTranscriptionResponse response = transcriptionModel.call(transcriptionRequest); + AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(this.audioFile, + transcriptionOptions); + AudioTranscriptionResponse response = this.transcriptionModel.call(transcriptionRequest); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue(); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelWithTranscriptionResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelWithTranscriptionResponseMetadataTests.java index a7749a8f0..a1b23b4a7 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelWithTranscriptionResponseMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelWithTranscriptionResponseMetadataTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.audio.transcription; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.header; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; -import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; +package org.springframework.ai.openai.audio.transcription; import java.time.Duration; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; + import org.springframework.ai.audio.transcription.AudioTranscriptionMetadata; import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; @@ -46,6 +42,12 @@ import org.springframework.http.MediaType; import org.springframework.test.web.client.MockRestServiceServer; import org.springframework.web.client.RestClient; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.header; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; +import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; + /** * @author Michael Lavelle */ @@ -62,7 +64,7 @@ public class OpenAiTranscriptionModelWithTranscriptionResponseMetadataTests { @AfterEach void resetMockServer() { - server.reset(); + this.server.reset(); } @Test @@ -118,7 +120,7 @@ public class OpenAiTranscriptionModelWithTranscriptionResponseMetadataTests { httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358"); httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); - server.expect(requestTo("/v1/audio/transcriptions")) + this.server.expect(requestTo("/v1/audio/transcriptions")) .andExpect(method(HttpMethod.POST)) .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) .andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders)); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/TranscriptionModelTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/TranscriptionModelTests.java index 4e7010035..1f93fe0f6 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/TranscriptionModelTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/TranscriptionModelTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.audio.transcription; import org.junit.jupiter.api.Test; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/ActorsFilms.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/ActorsFilms.java index 80320186d..1226618b6 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/ActorsFilms.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/ActorsFilms.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat; import java.util.List; @@ -27,7 +28,7 @@ public class ActorsFilms { } public String getActor() { - return actor; + return this.actor; } public void setActor(String actor) { @@ -35,7 +36,7 @@ public class ActorsFilms { } public List getMovies() { - return movies; + return this.movies; } public void setMovies(List movies) { @@ -44,7 +45,7 @@ public class ActorsFilms { @Override public String toString() { - return "ActorsFilms{" + "actor='" + actor + '\'' + ", movies=" + movies + '}'; + return "ActorsFilms{" + "actor='" + this.actor + '\'' + ", movies=" + this.movies + '}'; } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java index 8a393c8aa..1fcb34af2 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,6 @@ package org.springframework.ai.openai.chat; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; - import java.net.MalformedURLException; import java.net.URL; import java.util.List; @@ -32,10 +29,12 @@ import org.mockito.Captor; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.ai.model.Media; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.Media; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; @@ -44,7 +43,8 @@ import org.springframework.http.ResponseEntity; import org.springframework.util.MimeTypeUtils; import org.springframework.util.MultiValueMap; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.when; /** * @author Christian Tzolov @@ -73,41 +73,42 @@ public class MessageTypeContentTests { @BeforeEach public void beforeEach() { - chatModel = new OpenAiChatModel(openAiApi); + this.chatModel = new OpenAiChatModel(this.openAiApi); } @Test public void systemMessageSimpleContentType() { - when(openAiApi.chatCompletionEntity(pomptCaptor.capture(), headersCaptor.capture())) + when(this.openAiApi.chatCompletionEntity(this.pomptCaptor.capture(), this.headersCaptor.capture())) .thenReturn(Mockito.mock(ResponseEntity.class)); - chatModel.call(new Prompt(List.of(new SystemMessage("test message")))); + this.chatModel.call(new Prompt(List.of(new SystemMessage("test message")))); - validateStringContent(pomptCaptor.getValue()); - assertThat(headersCaptor.getValue()).isEmpty(); + validateStringContent(this.pomptCaptor.getValue()); + assertThat(this.headersCaptor.getValue()).isEmpty(); } @Test public void userMessageSimpleContentType() { - when(openAiApi.chatCompletionEntity(pomptCaptor.capture(), headersCaptor.capture())) + when(this.openAiApi.chatCompletionEntity(this.pomptCaptor.capture(), this.headersCaptor.capture())) .thenReturn(Mockito.mock(ResponseEntity.class)); - chatModel.call(new Prompt(List.of(new UserMessage("test message")))); + this.chatModel.call(new Prompt(List.of(new UserMessage("test message")))); - validateStringContent(pomptCaptor.getValue()); + validateStringContent(this.pomptCaptor.getValue()); } @Test public void streamUserMessageSimpleContentType() { - when(openAiApi.chatCompletionStream(pomptCaptor.capture(), headersCaptor.capture())).thenReturn(fluxResponse); + when(this.openAiApi.chatCompletionStream(this.pomptCaptor.capture(), this.headersCaptor.capture())) + .thenReturn(this.fluxResponse); - chatModel.stream(new Prompt(List.of(new UserMessage("test message")))).subscribe(); + this.chatModel.stream(new Prompt(List.of(new UserMessage("test message")))).subscribe(); - validateStringContent(pomptCaptor.getValue()); - assertThat(headersCaptor.getValue()).isEmpty(); + validateStringContent(this.pomptCaptor.getValue()); + assertThat(this.headersCaptor.getValue()).isEmpty(); } private void validateStringContent(ChatCompletionRequest chatCompletionRequest) { @@ -121,28 +122,29 @@ public class MessageTypeContentTests { @Test public void userMessageWithMediaType() throws MalformedURLException { - when(openAiApi.chatCompletionEntity(pomptCaptor.capture(), headersCaptor.capture())) + when(this.openAiApi.chatCompletionEntity(this.pomptCaptor.capture(), this.headersCaptor.capture())) .thenReturn(Mockito.mock(ResponseEntity.class)); URL mediaUrl = new URL("http://test"); - chatModel.call(new Prompt( + this.chatModel.call(new Prompt( List.of(new UserMessage("test message", List.of(new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl)))))); - validateComplexContent(pomptCaptor.getValue()); + validateComplexContent(this.pomptCaptor.getValue()); } @Test public void streamUserMessageWithMediaType() throws MalformedURLException { - when(openAiApi.chatCompletionStream(pomptCaptor.capture(), headersCaptor.capture())).thenReturn(fluxResponse); + when(this.openAiApi.chatCompletionStream(this.pomptCaptor.capture(), this.headersCaptor.capture())) + .thenReturn(this.fluxResponse); URL mediaUrl = new URL("http://test"); - chatModel + this.chatModel .stream(new Prompt( List.of(new UserMessage("test message", List.of(new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl)))))) .subscribe(); - validateComplexContent(pomptCaptor.getValue()); + validateComplexContent(this.pomptCaptor.getValue()); } private void validateComplexContent(ChatCompletionRequest chatCompletionRequest) { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModeAdditionalHttpHeadersIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModeAdditionalHttpHeadersIT.java index 7bb2a9813..cce923b87 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModeAdditionalHttpHeadersIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModeAdditionalHttpHeadersIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.Assert.assertThrows; +package org.springframework.ai.openai.chat; import java.util.Map; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.openai.OpenAiChatModel; @@ -33,6 +32,9 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertThrows; + /** * @author Christian Tzolov */ diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java index 1c1087257..5caa848b9 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -38,13 +47,6 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import reactor.core.publisher.Flux; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.function.BiFunction; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -112,7 +114,7 @@ class OpenAiChatModelFunctionCallingIT { List messages = new ArrayList<>(List.of(userMessage)); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -175,7 +177,7 @@ class OpenAiChatModelFunctionCallingIT { List messages = new ArrayList<>(List.of(userMessage)); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -205,4 +207,4 @@ class OpenAiChatModelFunctionCallingIT { } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index 66a1fd7b4..b12a10e61 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.chat; import java.io.IOException; import java.net.URL; @@ -34,9 +33,10 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.model.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -47,6 +47,7 @@ import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.OpenAiTestConfiguration; @@ -60,7 +61,7 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = OpenAiTestConfiguration.class) @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @@ -75,10 +76,10 @@ public class OpenAiChatModelIT extends AbstractIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); // needs fine tuning... evaluateQuestionAndAnswer(request, response, false); @@ -88,16 +89,16 @@ public class OpenAiChatModelIT extends AbstractIT { void testMessageHistory() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); var promptWithMessageHistory = new Prompt(List.of(new UserMessage("Dummy"), response.getResult().getOutput(), new UserMessage("Repeat the last assistant message."))); - response = chatModel.call(promptWithMessageHistory); + response = this.chatModel.call(promptWithMessageHistory); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); } @@ -111,7 +112,7 @@ public class OpenAiChatModelIT extends AbstractIT { StringBuilder answer = new StringBuilder(); CountDownLatch latch = new CountDownLatch(1); - Flux chatResponseFlux = streamingChatModel.stream(prompt).doOnNext(chatResponse -> { + Flux chatResponseFlux = this.streamingChatModel.stream(prompt).doOnNext(chatResponse -> { String responseContent = chatResponse.getResults().get(0).getOutput().getContent(); answer.append(responseContent); }).doOnComplete(() -> { @@ -133,7 +134,7 @@ public class OpenAiChatModelIT extends AbstractIT { StringBuilder answer = new StringBuilder(); CountDownLatch latch = new CountDownLatch(1); - ChatClient chatClient = ChatClient.builder(openAiChatModel).build(); + ChatClient chatClient = ChatClient.builder(this.openAiChatModel).build(); Flux chatResponseFlux = chatClient.prompt(prompt) .stream() @@ -159,7 +160,7 @@ public class OpenAiChatModelIT extends AbstractIT { StringBuilder answer = new StringBuilder(); CountDownLatch latch = new CountDownLatch(1); - ChatClient chatClient = ChatClient.builder(openAiChatModel).build(); + ChatClient chatClient = ChatClient.builder(this.openAiChatModel).build(); Flux chatResponseFlux = chatClient.prompt(prompt) .stream() @@ -178,10 +179,10 @@ public class OpenAiChatModelIT extends AbstractIT { void streamRoleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - Flux flux = streamingChatModel.stream(prompt); + Flux flux = this.streamingChatModel.stream(prompt); List responses = flux.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -247,7 +248,7 @@ public class OpenAiChatModelIT extends AbstractIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -266,14 +267,11 @@ public class OpenAiChatModelIT extends AbstractIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -286,7 +284,7 @@ public class OpenAiChatModelIT extends AbstractIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -307,7 +305,7 @@ public class OpenAiChatModelIT extends AbstractIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = streamingChatModel.stream(prompt) + String generationTextFromStream = this.streamingChatModel.stream(prompt) .collectList() .block() .stream() @@ -339,7 +337,7 @@ public class OpenAiChatModelIT extends AbstractIT { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -364,7 +362,7 @@ public class OpenAiChatModelIT extends AbstractIT { .build())) .build(); - Flux response = streamingChatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.streamingChatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -390,7 +388,7 @@ public class OpenAiChatModelIT extends AbstractIT { var userMessage = new UserMessage("Explain what do you see on this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel + var response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -406,7 +404,7 @@ public class OpenAiChatModelIT extends AbstractIT { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - ChatResponse response = chatModel + ChatResponse response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -421,7 +419,7 @@ public class OpenAiChatModelIT extends AbstractIT { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - Flux response = streamingChatModel.stream(new Prompt(List.of(userMessage), + Flux response = this.streamingChatModel.stream(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()).build())); String content = response.collectList() @@ -441,7 +439,7 @@ public class OpenAiChatModelIT extends AbstractIT { void validateCallResponseMetadata() { String model = OpenAiApi.ChatModel.GPT_3_5_TURBO.getName(); // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(OpenAiChatOptions.builder().withModel(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() @@ -456,4 +454,8 @@ public class OpenAiChatModelIT extends AbstractIT { assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } -} \ No newline at end of file + record ActorsFilmsRecord(String actor, List movies) { + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java index 362135eb4..36edf6952 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat; +import java.util.List; +import java.util.stream.Collectors; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; -import reactor.core.publisher.Flux; - import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; @@ -38,9 +42,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import java.util.List; -import java.util.stream.Collectors; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames; @@ -62,7 +63,7 @@ public class OpenAiChatModelObservationIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -80,7 +81,7 @@ public class OpenAiChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -104,7 +105,7 @@ public class OpenAiChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -125,7 +126,7 @@ public class OpenAiChatModelObservationIT { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java index c4e1198ee..dc43e4e7f 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.chat; import java.util.ArrayList; import java.util.List; @@ -25,10 +24,16 @@ import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -37,8 +42,8 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.ToolCallHelper; import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.ToolCallHelper; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; @@ -49,12 +54,7 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.util.CollectionUtils; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.JsonMappingException; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.micrometer.observation.ObservationRegistry; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = OpenAiChatModelProxyToolCallsIT.Config.class) @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @@ -64,31 +64,6 @@ class OpenAiChatModelProxyToolCallsIT { private static final String DEFAULT_MODEL = "gpt-4o-mini"; - @Autowired - private OpenAiChatModel chatModel; - - // Helper class that reuses some of the {@link AbstractToolCallSupport} functionality - // to help to implement the function call handling logic on the client side. - private ToolCallHelper toolCallHelper = new ToolCallHelper(); - - // Function which will be called by the AI model. - private String getWeatherInLocation(String location, String unit) { - - double temperature = 0; - - if (location.contains("Paris")) { - temperature = 15; - } - else if (location.contains("Tokyo")) { - temperature = 10; - } - else if (location.contains("San Francisco")) { - temperature = 30; - } - - return String.format("The weather in %s is %s%s", location, temperature, unit); - } - FunctionCallback functionDefinition = new ToolCallHelper.FunctionDefinition("getWeatherInLocation", "Get the weather in location", """ { @@ -107,13 +82,48 @@ class OpenAiChatModelProxyToolCallsIT { } """); + @Autowired + private OpenAiChatModel chatModel; + + // Helper class that reuses some of the {@link AbstractToolCallSupport} functionality + // to help to implement the function call handling logic on the client side. + private ToolCallHelper toolCallHelper = new ToolCallHelper(); + + @SuppressWarnings("unchecked") + private static Map getFunctionArguments(String functionArguments) { + try { + return new ObjectMapper().readValue(functionArguments, Map.class); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + // Function which will be called by the AI model. + private String getWeatherInLocation(String location, String unit) { + + double temperature = 0; + + if (location.contains("Paris")) { + temperature = 15; + } + else if (location.contains("Tokyo")) { + temperature = 10; + } + else if (location.contains("San Francisco")) { + temperature = 30; + } + + return String.format("The weather in %s is %s%s", location, temperature, unit); + } + @Test void functionCall() throws JsonMappingException, JsonProcessingException { List messages = List .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); - var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(functionDefinition)).build(); + var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(this.functionDefinition)).build(); var prompt = new Prompt(messages, promptOptions); @@ -123,13 +133,13 @@ class OpenAiChatModelProxyToolCallsIT { do { - chatResponse = chatModel.call(prompt); + chatResponse = this.chatModel.call(prompt); // We will have to convert the chatResponse into OpenAI assistant message. // Note that the tool call check could be platform specific because the finish // reasons. - isToolCall = toolCallHelper.isToolCall(chatResponse, + isToolCall = this.toolCallHelper.isToolCall(chatResponse, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), OpenAiApi.ChatCompletionFinishReason.STOP.name())); @@ -166,8 +176,8 @@ class OpenAiChatModelProxyToolCallsIT { ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of()); - List toolCallConversation = toolCallHelper.buildToolCallConversation(prompt.getInstructions(), - assistantMessage, toolMessageResponse); + List toolCallConversation = this.toolCallHelper + .buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse); assertThat(toolCallConversation).isNotEmpty(); @@ -187,7 +197,7 @@ class OpenAiChatModelProxyToolCallsIT { List messages = List .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); - var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(functionDefinition)).build(); + var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(this.functionDefinition)).build(); var prompt = new Prompt(messages, promptOptions); @@ -222,11 +232,11 @@ class OpenAiChatModelProxyToolCallsIT { private Flux processToolCall(Prompt prompt, Set finishReasons, Function customFunction) { - Flux chatResponses = chatModel.stream(prompt); + Flux chatResponses = this.chatModel.stream(prompt); return chatResponses.flatMap(chatResponse -> { - boolean isToolCall = toolCallHelper.isToolCall(chatResponse, finishReasons); + boolean isToolCall = this.toolCallHelper.isToolCall(chatResponse, finishReasons); if (isToolCall) { @@ -251,8 +261,8 @@ class OpenAiChatModelProxyToolCallsIT { ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of()); - List toolCallConversation = toolCallHelper.buildToolCallConversation(prompt.getInstructions(), - assistantMessage, toolMessageResponse); + List toolCallConversation = this.toolCallHelper + .buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse); assertThat(toolCallConversation).isNotEmpty(); @@ -271,11 +281,11 @@ class OpenAiChatModelProxyToolCallsIT { List messages = List .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); - var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(functionDefinition)).build(); + var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(this.functionDefinition)).build(); var prompt = new Prompt(messages, promptOptions); - ChatResponse chatResponse = toolCallHelper.processCall(chatModel, prompt, + ChatResponse chatResponse = this.toolCallHelper.processCall(this.chatModel, prompt, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), OpenAiApi.ChatCompletionFinishReason.STOP.name()), toolCall -> { @@ -305,11 +315,11 @@ class OpenAiChatModelProxyToolCallsIT { List messages = List .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); - var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(functionDefinition)).build(); + var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(this.functionDefinition)).build(); var prompt = new Prompt(messages, promptOptions); - Flux responses = toolCallHelper.processStream(chatModel, prompt, + Flux responses = this.toolCallHelper.processStream(this.chatModel, prompt, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), OpenAiApi.ChatCompletionFinishReason.STOP.name()), toolCall -> { @@ -340,16 +350,6 @@ class OpenAiChatModelProxyToolCallsIT { } - @SuppressWarnings("unchecked") - private static Map getFunctionArguments(String functionArguments) { - try { - return new ObjectMapper().readValue(functionArguments, Map.class); - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - @SpringBootConfiguration static class Config { @@ -369,4 +369,4 @@ class OpenAiChatModelProxyToolCallsIT { } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java index f1c5859d9..63bf6a88a 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat; -import static org.assertj.core.api.Assertions.assertThat; - +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JacksonException; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.converter.BeanOutputConverter; @@ -34,12 +40,7 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.core.JacksonException; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.JsonMappingException; -import com.fasterxml.jackson.databind.ObjectMapper; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -48,11 +49,23 @@ import com.fasterxml.jackson.databind.ObjectMapper; @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") public class OpenAiChatModelResponseFormatIT { + private static ObjectMapper MAPPER = new ObjectMapper().enable(DeserializationFeature.FAIL_ON_TRAILING_TOKENS); + private final Logger logger = LoggerFactory.getLogger(getClass()); @Autowired private OpenAiChatModel openAiChatModel; + public static boolean isValidJson(String json) { + try { + MAPPER.readTree(json); + } + catch (JacksonException e) { + return false; + } + return true; + } + @Test void jsonObject() throws JsonMappingException, JsonProcessingException { @@ -76,7 +89,7 @@ public class OpenAiChatModelResponseFormatIT { String content = response.getResult().getOutput().getContent(); - logger.info("Response content: {}", content); + this.logger.info("Response content: {}", content); assertThat(isValidJson(content)).isTrue(); } @@ -119,7 +132,7 @@ public class OpenAiChatModelResponseFormatIT { String content = response.getResult().getOutput().getContent(); - logger.info("Response content: {}", content); + this.logger.info("Response content: {}", content); assertThat(isValidJson(content)).isTrue(); } @@ -134,8 +147,11 @@ public class OpenAiChatModelResponseFormatIT { record Items(@JsonProperty(required = true, value = "explanation") String explanation, @JsonProperty(required = true, value = "output") String output) { + } + } + } var outputConverter = new BeanOutputConverter<>(MathReasoning.class); @@ -156,7 +172,7 @@ public class OpenAiChatModelResponseFormatIT { String content = response.getResult().getOutput().getContent(); - logger.info("Response content: {}", content); + this.logger.info("Response content: {}", content); MathReasoning mathReasoning = outputConverter.convert(content); @@ -165,18 +181,6 @@ public class OpenAiChatModelResponseFormatIT { assertThat(isValidJson(content)).isTrue(); } - private static ObjectMapper MAPPER = new ObjectMapper().enable(DeserializationFeature.FAIL_ON_TRAILING_TOKENS); - - public static boolean isValidJson(String json) { - try { - MAPPER.readTree(json); - } - catch (JacksonException e) { - return false; - } - return true; - } - @SpringBootConfiguration static class Config { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelTypeReferenceBeanOutputConverterIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelTypeReferenceBeanOutputConverterIT.java index d222a64c3..443c65c79 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelTypeReferenceBeanOutputConverterIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelTypeReferenceBeanOutputConverterIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat; import java.util.List; @@ -24,9 +25,9 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.converter.BeanOutputConverter; @@ -44,14 +45,12 @@ class OpenAiChatModelTypeReferenceBeanOutputConverterIT extends AbstractIT { private static final Logger logger = LoggerFactory .getLogger(OpenAiChatModelTypeReferenceBeanOutputConverterIT.class); - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void typeRefOutputConverterRecords() { BeanOutputConverter> outputConverter = new BeanOutputConverter<>( new ParameterizedTypeReference>() { + }); String format = outputConverter.getFormat(); @@ -61,7 +60,7 @@ class OpenAiChatModelTypeReferenceBeanOutputConverterIT extends AbstractIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); List actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -77,6 +76,7 @@ class OpenAiChatModelTypeReferenceBeanOutputConverterIT extends AbstractIT { BeanOutputConverter> outputConverter = new BeanOutputConverter<>( new ParameterizedTypeReference>() { + }); String format = outputConverter.getFormat(); @@ -87,7 +87,7 @@ class OpenAiChatModelTypeReferenceBeanOutputConverterIT extends AbstractIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = streamingChatModel.stream(prompt) + String generationTextFromStream = this.streamingChatModel.stream(prompt) .collectList() .block() .stream() @@ -106,4 +106,8 @@ class OpenAiChatModelTypeReferenceBeanOutputConverterIT extends AbstractIT { assertThat(actorsFilms.get(1).movies()).hasSize(5); } -} \ No newline at end of file + record ActorsFilmsRecord(String actor, List movies) { + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java index 4afeca476..2e6f6ddb6 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat; import java.time.Duration; @@ -20,16 +21,16 @@ import java.time.Duration; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.PromptMetadata; import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.chat.metadata.Usage; -import org.springframework.ai.openai.api.OpenAiApi; -import org.springframework.ai.openai.OpenAiChatModel; -import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.autoconfigure.web.client.RestClientTest; @@ -65,7 +66,7 @@ public class OpenAiChatModelWithChatResponseMetadataTests { @AfterEach void resetMockServer() { - server.reset(); + this.server.reset(); } @Test @@ -132,7 +133,7 @@ public class OpenAiChatModelWithChatResponseMetadataTests { httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358"); httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); - server.expect(requestTo("/v1/chat/completions")) + this.server.expect(requestTo("/v1/chat/completions")) .andExpect(method(HttpMethod.POST)) .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) .andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders)); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java index 4a4e30719..10645ee56 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; @@ -30,11 +37,6 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThat; @@ -48,7 +50,9 @@ public class OpenAiCompatibleChatModelIT { static OpenAiChatOptions forModelName(String modelName) { return OpenAiChatOptions.builder().withModel(modelName).build(); - }; + } + + ; static Stream openAiCompatibleApis() { Stream.Builder builder = Stream.builder(); @@ -72,7 +76,7 @@ public class OpenAiCompatibleChatModelIT { @ParameterizedTest @MethodSource("openAiCompatibleApis") void chatCompletion(ChatModel chatModel) { - Prompt prompt = new Prompt(conversation); + Prompt prompt = new Prompt(this.conversation); ChatResponse response = chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); @@ -82,7 +86,7 @@ public class OpenAiCompatibleChatModelIT { @ParameterizedTest @MethodSource("openAiCompatibleApis") void streamCompletion(StreamingChatModel streamingChatModel) { - Prompt prompt = new Prompt(conversation); + Prompt prompt = new Prompt(this.conversation); Flux flux = streamingChatModel.stream(prompt); List responses = flux.collectList().block(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java index 26635b2fd..97c38b13b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,6 @@ package org.springframework.ai.openai.chat; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.List; import java.util.Map; import java.util.function.Function; @@ -28,11 +26,13 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.openai.OpenAiChatModel; @@ -48,7 +48,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Description; import org.springframework.core.ParameterizedTypeReference; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -59,55 +59,12 @@ public class OpenAiPaymentTransactionIT { private final static Logger logger = LoggerFactory.getLogger(OpenAiPaymentTransactionIT.class); + private static final Map DATASET = Map.of(new Transaction("001"), new Status("pending"), + new Transaction("002"), new Status("approved"), new Transaction("003"), new Status("rejected")); + @Autowired ChatClient chatClient; - record TransactionStatusResponse(String id, String status) { - } - - private static class LoggingAdvisor implements CallAroundAdvisor { - - private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); - - public String getName() { - return this.getClass().getSimpleName(); - } - - @Override - public int getOrder() { - return 0; - } - - @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - - advisedRequest = this.before(advisedRequest); - - AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest); - - this.observeAfter(advisedResponse); - - return advisedResponse; - } - - private AdvisedRequest before(AdvisedRequest request) { - logger.info("System text: \n" + request.systemText()); - logger.info("System params: " + request.systemParams()); - logger.info("User text: \n" + request.userText()); - logger.info("User params:" + request.userParams()); - logger.info("Function names: " + request.functionNames()); - - logger.info("Options: " + request.chatOptions().toString()); - - return request; - } - - private void observeAfter(AdvisedResponse advisedResponse) { - logger.info("Response: " + advisedResponse.response()); - } - - } - @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "paymentStatus", "paymentStatuses" }) public void transactionPaymentStatuses(String functionName) { @@ -119,6 +76,7 @@ public class OpenAiPaymentTransactionIT { """) .call() .entity(new ParameterizedTypeReference>() { + }); logger.info("" + content); @@ -138,6 +96,7 @@ public class OpenAiPaymentTransactionIT { public void streamingPaymentStatuses(String functionName) { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference>() { + }); Flux flux = this.chatClient.prompt() @@ -166,20 +125,68 @@ public class OpenAiPaymentTransactionIT { assertThat(structure.get(2).status()).isEqualTo("rejected"); } + record TransactionStatusResponse(String id, String status) { + + } + + private static class LoggingAdvisor implements CallAroundAdvisor { + + private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); + + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public int getOrder() { + return 0; + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + + advisedRequest = this.before(advisedRequest); + + AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest); + + this.observeAfter(advisedResponse); + + return advisedResponse; + } + + private AdvisedRequest before(AdvisedRequest request) { + this.logger.info("System text: \n" + request.systemText()); + this.logger.info("System params: " + request.systemParams()); + this.logger.info("User text: \n" + request.userText()); + this.logger.info("User params:" + request.userParams()); + this.logger.info("Function names: " + request.functionNames()); + + this.logger.info("Options: " + request.chatOptions().toString()); + + return request; + } + + private void observeAfter(AdvisedResponse advisedResponse) { + this.logger.info("Response: " + advisedResponse.response()); + } + + } + record Transaction(String id) { + } record Status(String name) { + } record Transactions(List transactions) { + } record Statuses(List statuses) { - } - private static final Map DATASET = Map.of(new Transaction("001"), new Status("pending"), - new Transaction("002"), new Status("approved"), new Transaction("003"), new Status("rejected")); + } @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java index bae33e60c..9b9e28e43 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat; import java.util.List; @@ -26,6 +27,8 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import reactor.core.publisher.Flux; +import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; +import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.image.ImageMessage; @@ -56,8 +59,6 @@ import org.springframework.ai.openai.api.OpenAiImageApi; import org.springframework.ai.openai.api.OpenAiImageApi.Data; import org.springframework.ai.openai.api.OpenAiImageApi.OpenAiImageRequest; import org.springframework.ai.openai.api.OpenAiImageApi.OpenAiImageResponse; -import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; -import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.retry.TransientAiException; import org.springframework.core.io.ClassPathResource; @@ -69,8 +70,8 @@ import org.springframework.retry.support.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.isA; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.when; /** @@ -80,25 +81,6 @@ import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) public class OpenAiRetryTests { - private static class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private RetryTemplate retryTemplate; @@ -119,20 +101,22 @@ public class OpenAiRetryTests { @BeforeEach public void beforeEach() { - retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); + this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.registerListener(this.retryListener); - chatModel = new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().build(), null, retryTemplate); - embeddingModel = new OpenAiEmbeddingModel(openAiApi, MetadataMode.EMBED, - OpenAiEmbeddingOptions.builder().build(), retryTemplate); - audioTranscriptionModel = new OpenAiAudioTranscriptionModel(openAiAudioApi, + this.chatModel = new OpenAiChatModel(this.openAiApi, OpenAiChatOptions.builder().build(), null, + this.retryTemplate); + this.embeddingModel = new OpenAiEmbeddingModel(this.openAiApi, MetadataMode.EMBED, + OpenAiEmbeddingOptions.builder().build(), this.retryTemplate); + this.audioTranscriptionModel = new OpenAiAudioTranscriptionModel(this.openAiAudioApi, OpenAiAudioTranscriptionOptions.builder() .withModel("model") .withResponseFormat(TranscriptResponseFormat.JSON) .build(), - retryTemplate); - imageModel = new OpenAiImageModel(openAiImageApi, OpenAiImageOptions.builder().build(), retryTemplate); + this.retryTemplate); + this.imageModel = new OpenAiImageModel(this.openAiImageApi, OpenAiImageOptions.builder().build(), + this.retryTemplate); } @Test @@ -143,24 +127,24 @@ public class OpenAiRetryTests { ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666l, "model", null, null, new OpenAiApi.Usage(10, 10, 10)); - when(openAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class), any())) + when(this.openAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class), any())) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); - var result = chatModel.call(new Prompt("text")); + var result = this.chatModel.call(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void openAiChatNonTransientError() { - when(openAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class), any())) + when(this.openAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class), any())) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("text"))); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); } @Test @@ -172,25 +156,25 @@ public class OpenAiRetryTests { ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666l, "model", null, null, null); - when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any())) + when(this.openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any())) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(Flux.just(expectedChatCompletion)); - var result = chatModel.stream(new Prompt("text")); + var result = this.chatModel.stream(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test @Disabled("Currently stream() does not implmement retry") public void openAiChatStreamNonTransientError() { - when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any())) + when(this.openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any())) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")).subscribe()); + assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text")).subscribe()); } @Test @@ -199,23 +183,25 @@ public class OpenAiRetryTests { EmbeddingList expectedEmbeddings = new EmbeddingList<>("list", List.of(new Embedding(0, new float[] { 9.9f, 8.8f })), "model", new OpenAiApi.Usage(10, 10, 10)); - when(openAiApi.embeddings(isA(EmbeddingRequest.class))).thenThrow(new TransientAiException("Transient Error 1")) + when(this.openAiApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); - var result = embeddingModel + var result = this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void openAiEmbeddingNonTransientError() { - when(openAiApi.embeddings(isA(EmbeddingRequest.class))).thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> embeddingModel + when(this.openAiApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); } @@ -224,25 +210,25 @@ public class OpenAiRetryTests { var expectedResponse = new StructuredResponse("nl", 6.7f, "Transcription Text", List.of(), List.of()); - when(openAiAudioApi.createTranscription(isA(TranscriptionRequest.class), isA(Class.class))) + when(this.openAiAudioApi.createTranscription(isA(TranscriptionRequest.class), isA(Class.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedResponse))); - AudioTranscriptionResponse result = audioTranscriptionModel + AudioTranscriptionResponse result = this.audioTranscriptionModel .call(new AudioTranscriptionPrompt(new ClassPathResource("speech/jfk.flac"))); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(expectedResponse.text()); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void openAiAudioTranscriptionNonTransientError() { - when(openAiAudioApi.createTranscription(isA(TranscriptionRequest.class), isA(Class.class))) + when(this.openAiAudioApi.createTranscription(isA(TranscriptionRequest.class), isA(Class.class))) .thenThrow(new RuntimeException("Transient Error 1")); - assertThrows(RuntimeException.class, () -> audioTranscriptionModel + assertThrows(RuntimeException.class, () -> this.audioTranscriptionModel .call(new AudioTranscriptionPrompt(new ClassPathResource("speech/jfk.flac")))); } @@ -251,25 +237,44 @@ public class OpenAiRetryTests { var expectedResponse = new OpenAiImageResponse(678l, List.of(new Data("url678", "b64", "prompt"))); - when(openAiImageApi.createImage(isA(OpenAiImageRequest.class))) + when(this.openAiImageApi.createImage(isA(OpenAiImageRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedResponse))); - var result = imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); + var result = this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getUrl()).isEqualTo("url678"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void openAiImageNonTransientError() { - when(openAiImageApi.createImage(isA(OpenAiImageRequest.class))) + when(this.openAiImageApi.createImage(isA(OpenAiImageRequest.class))) .thenThrow(new RuntimeException("Transient Error 1")); assertThrows(RuntimeException.class, - () -> imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); + () -> this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java index 1af4b2065..6a3587252 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat.client; import java.io.IOException; @@ -62,9 +63,6 @@ class OpenAiChatClientIT extends AbstractIT { @Value("classpath:/prompts/system-message.st") private Resource systemTextResource; - record ActorsFilms(String actor, List movies) { - } - @Test @Disabled("Although the Re2 advisor improves the response correctness it is not always guarantied to work.") void re2() { @@ -79,12 +77,12 @@ class OpenAiChatClientIT extends AbstractIT { """; // @formatter:off - ChatClient chatClient = ChatClient.builder(chatModel) + ChatClient chatClient = ChatClient.builder(this.chatModel) .defaultOptions(OpenAiChatOptions.builder() .withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()).build()) .defaultUser(REASON_QUESTION) .build(); - + String response = chatClient.prompt() .advisors(new ReReadingAdvisor()) .call() @@ -101,9 +99,9 @@ class OpenAiChatClientIT extends AbstractIT { void call() { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .advisors(new SimpleLoggerAdvisor()) - .system(s -> s.text(systemTextResource) + .system(s -> s.text(this.systemTextResource) .param("name", "Bob") .param("voice", "pirate")) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") @@ -119,7 +117,7 @@ class OpenAiChatClientIT extends AbstractIT { @Test void listOutputConverterString() { // @formatter:off - List collection = ChatClient.create(chatModel).prompt() + List collection = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("List five {subject}") .param("subject", "ice cream flavors")) .call() @@ -134,7 +132,7 @@ class OpenAiChatClientIT extends AbstractIT { void listOutputConverterBean() { // @formatter:off - List actorsFilms = ChatClient.create(chatModel).prompt() + List actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography of 5 movies for Tom Hanks and Bill Murray.") .call() .entity(new ParameterizedTypeReference>() { @@ -151,7 +149,7 @@ class OpenAiChatClientIT extends AbstractIT { var toStringListConverter = new ListOutputConverter(new DefaultConversionService()); // @formatter:off - List flavors = ChatClient.create(chatModel).prompt() + List flavors = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("List five {subject}") .param("subject", "ice cream flavors")) .call() @@ -166,7 +164,7 @@ class OpenAiChatClientIT extends AbstractIT { @Test void mapOutputConverter() { // @formatter:off - Map result = ChatClient.create(chatModel).prompt() + Map result = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("Provide me a List of {subject}") .param("subject", "an array of numbers from 1 to 9 under they key name 'numbers'")) .call() @@ -181,7 +179,7 @@ class OpenAiChatClientIT extends AbstractIT { void beanOutputConverter() { // @formatter:off - ActorsFilms actorsFilms = ChatClient.create(chatModel).prompt() + ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography for a random actor.") .call() .entity(ActorsFilms.class); @@ -195,7 +193,7 @@ class OpenAiChatClientIT extends AbstractIT { void beanOutputConverterRecords() { // @formatter:off - ActorsFilms actorsFilms = ChatClient.create(chatModel).prompt() + ActorsFilms actorsFilms = ChatClient.create(this.chatModel).prompt() .user("Generate the filmography of 5 movies for Tom Hanks.") .call() .entity(ActorsFilms.class); @@ -212,7 +210,7 @@ class OpenAiChatClientIT extends AbstractIT { BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilms.class); // @formatter:off - Flux chatResponse = ChatClient.create(chatModel) + Flux chatResponse = ChatClient.create(this.chatModel) .prompt() .options(OpenAiChatOptions.builder().withStreamUsage(true).build()) .advisors(new SimpleLoggerAdvisor()) @@ -246,7 +244,7 @@ class OpenAiChatClientIT extends AbstractIT { void functionCallTest() { // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .call() @@ -262,7 +260,7 @@ class OpenAiChatClientIT extends AbstractIT { void defaultFunctionCallTest() { // @formatter:off - String response = ChatClient.builder(chatModel) + String response = ChatClient.builder(this.chatModel) .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .build() @@ -278,7 +276,7 @@ class OpenAiChatClientIT extends AbstractIT { void streamFunctionCallTest() { // @formatter:off - Flux response = ChatClient.create(chatModel).prompt() + Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .stream() @@ -296,7 +294,7 @@ class OpenAiChatClientIT extends AbstractIT { void multiModalityEmbeddedImage(String modelName) throws IOException { // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() .options(OpenAiChatOptions.builder().withModel(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?") .media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png"))) @@ -317,7 +315,7 @@ class OpenAiChatClientIT extends AbstractIT { URL url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); // @formatter:off - String response = ChatClient.create(chatModel).prompt() + String response = ChatClient.create(this.chatModel).prompt() // TODO consider adding model(...) method to ChatClient as a shortcut to .options(OpenAiChatOptions.builder().withModel(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url)) @@ -337,7 +335,7 @@ class OpenAiChatClientIT extends AbstractIT { URL url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); // @formatter:off - Flux response = ChatClient.create(chatModel).prompt() + Flux response = ChatClient.create(this.chatModel).prompt() .options(OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()) .build()) .user(u -> u.text("Explain what do you see on this picture?") @@ -353,4 +351,8 @@ class OpenAiChatClientIT extends AbstractIT { assertThat(content).containsAnyOf("bowl", "basket"); } -} \ No newline at end of file + record ActorsFilms(String actor, List movies) { + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java index 44ab8becf..e8a4fbb17 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat.client; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.chat.client; import java.lang.reflect.Method; import java.util.List; @@ -28,6 +27,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.openai.OpenAiTestConfiguration; @@ -40,7 +41,7 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.io.Resource; import org.springframework.test.context.ActiveProfiles; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = OpenAiTestConfiguration.class) @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @@ -52,13 +53,21 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT { @Value("classpath:/prompts/system-message.st") private Resource systemTextResource; - record ActorsFilms(String actor, List movies) { + public static Function createFunction(Object obj, Method method) { + return (T t) -> { + try { + return (R) method.invoke(obj, t); + } + catch (Exception e) { + throw new RuntimeException(e); + } + }; } @Test void turnFunctionsOnAndOffTest() { - var chatClientBuilder = ChatClient.builder(chatModel); + var chatClientBuilder = ChatClient.builder(this.chatModel); // @formatter:off String response = chatClientBuilder.build().prompt() @@ -100,7 +109,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT { void defaultFunctionCallTest() { // @formatter:off - String response = ChatClient.builder(chatModel) + String response = ChatClient.builder(this.chatModel) .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .build() @@ -139,7 +148,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT { }; // @formatter:off - String response = ChatClient.builder(chatModel) + String response = ChatClient.builder(this.chatModel) .defaultFunction("getCurrentWeather", "Get the weather in location", biFunction) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .defaultToolContext(Map.of("sessionId", "123")) @@ -179,7 +188,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT { }; // @formatter:off - String response = ChatClient.builder(chatModel) + String response = ChatClient.builder(this.chatModel) .defaultFunction("getCurrentWeather", "Get the weather in location", biFunction) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .build() @@ -197,7 +206,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT { void streamFunctionCallTest() { // @formatter:off - Flux response = ChatClient.create(chatModel).prompt() + Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) .stream() @@ -214,7 +223,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT { @Test void functionCallWithExplicitInputType() throws NoSuchMethodException { - var chatClient = ChatClient.create(chatModel); + var chatClient = ChatClient.create(this.chatModel); Method currentTemp = MyFunction.class.getMethod("getCurrentTemp", MyFunction.Req.class); @@ -232,26 +241,20 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT { assertThat(content).contains("23"); } - public static Function createFunction(Object obj, Method method) { - return (T t) -> { - try { - return (R) method.invoke(obj, t); - } - catch (Exception e) { - throw new RuntimeException(e); - } - }; + record ActorsFilms(String actor, List movies) { + } public static class MyFunction { - public record Req(String city) { - } - public String getCurrentTemp(Req req) { return "23"; } + public record Req(String city) { + + } + } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ReReadingAdvisor.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ReReadingAdvisor.java index 47d3d2af7..0ad524689 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ReReadingAdvisor.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ReReadingAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,11 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.chat.client; import java.util.HashMap; import java.util.Map; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; @@ -25,8 +28,6 @@ import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; -import reactor.core.publisher.Flux; - /** * Drawing inspiration from the human strategy of re-reading, this advisor implements a * re-reading strategy for LLM reasoning, dubbed RE2, to enhance understanding in the @@ -91,4 +92,4 @@ public class ReReadingAdvisor implements CallAroundAdvisor, StreamAroundAdvisor return this; } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java index c3cfb550a..00e126334 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat.proxy; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.chat.proxy; import java.io.IOException; import java.net.URL; @@ -32,6 +31,8 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -62,7 +63,7 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = GroqWithOpenAiChatModelIT.Config.class) @EnabledIfEnvironmentVariable(named = "GROQ_API_KEY", matches = ".+") @@ -85,10 +86,10 @@ class GroqWithOpenAiChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); } @@ -97,10 +98,10 @@ class GroqWithOpenAiChatModelIT { void streamRoleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - Flux flux = chatModel.stream(prompt); + Flux flux = this.chatModel.stream(prompt); List responses = flux.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -167,7 +168,7 @@ class GroqWithOpenAiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -186,15 +187,12 @@ class GroqWithOpenAiChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.getActor()).isNotEmpty(); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -207,7 +205,7 @@ class GroqWithOpenAiChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -228,7 +226,7 @@ class GroqWithOpenAiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -259,7 +257,7 @@ class GroqWithOpenAiChatModelIT { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -282,7 +280,7 @@ class GroqWithOpenAiChatModelIT { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -307,7 +305,7 @@ class GroqWithOpenAiChatModelIT { var userMessage = new UserMessage("Explain what do you see on this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel + var response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -324,7 +322,7 @@ class GroqWithOpenAiChatModelIT { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - ChatResponse response = chatModel + ChatResponse response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -340,7 +338,7 @@ class GroqWithOpenAiChatModelIT { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - Flux response = chatModel.stream(new Prompt(List.of(userMessage))); + Flux response = this.chatModel.stream(new Prompt(List.of(userMessage))); String content = response.collectList() .block() @@ -359,7 +357,7 @@ class GroqWithOpenAiChatModelIT { @ValueSource(strings = { "llama3-8b-8192", "llama3-70b-8192", "mixtral-8x7b-32768", "gemma-7b-it" }) void validateCallResponseMetadata(String model) { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(OpenAiChatOptions.builder().withModel(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() @@ -374,6 +372,10 @@ class GroqWithOpenAiChatModelIT { assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration static class Config { @@ -389,4 +391,4 @@ class GroqWithOpenAiChatModelIT { } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java index 9d7a66e95..51d29144d 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat.proxy; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.chat.proxy; import java.io.IOException; import java.net.URL; @@ -32,9 +31,10 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.model.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -45,6 +45,7 @@ import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; @@ -61,7 +62,7 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = MistralWithOpenAiChatModelIT.Config.class) @EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+") @@ -83,10 +84,10 @@ class MistralWithOpenAiChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); } @@ -95,10 +96,10 @@ class MistralWithOpenAiChatModelIT { void streamRoleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - Flux flux = chatModel.stream(prompt); + Flux flux = this.chatModel.stream(prompt); List responses = flux.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -165,7 +166,7 @@ class MistralWithOpenAiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -184,15 +185,12 @@ class MistralWithOpenAiChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.getActor()).isNotEmpty(); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -205,7 +203,7 @@ class MistralWithOpenAiChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -226,7 +224,7 @@ class MistralWithOpenAiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -260,7 +258,7 @@ class MistralWithOpenAiChatModelIT { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -285,7 +283,7 @@ class MistralWithOpenAiChatModelIT { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -310,7 +308,7 @@ class MistralWithOpenAiChatModelIT { var userMessage = new UserMessage("Explain what do you see on this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel + var response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -327,7 +325,7 @@ class MistralWithOpenAiChatModelIT { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - ChatResponse response = chatModel + ChatResponse response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -343,7 +341,7 @@ class MistralWithOpenAiChatModelIT { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - Flux response = chatModel.stream(new Prompt(List.of(userMessage))); + Flux response = this.chatModel.stream(new Prompt(List.of(userMessage))); String content = response.collectList() .block() @@ -363,7 +361,7 @@ class MistralWithOpenAiChatModelIT { "open-mixtral-8x22b" }) void validateCallResponseMetadata(String model) { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(OpenAiChatOptions.builder().withModel(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() @@ -378,6 +376,10 @@ class MistralWithOpenAiChatModelIT { assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration static class Config { @@ -393,4 +395,4 @@ class MistralWithOpenAiChatModelIT { } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java index 4b713f220..4c5ad7bc3 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat.proxy; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.chat.proxy; import java.util.ArrayList; import java.util.Arrays; @@ -28,6 +27,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -54,7 +55,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.core.io.Resource; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -81,10 +82,10 @@ class NvidiaWithOpenAiChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); } @@ -93,10 +94,10 @@ class NvidiaWithOpenAiChatModelIT { void streamRoleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - Flux flux = chatModel.stream(prompt); + Flux flux = this.chatModel.stream(prompt); List responses = flux.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -162,7 +163,7 @@ class NvidiaWithOpenAiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -181,15 +182,12 @@ class NvidiaWithOpenAiChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.getActor()).isNotEmpty(); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -202,7 +200,7 @@ class NvidiaWithOpenAiChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -223,7 +221,7 @@ class NvidiaWithOpenAiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -255,7 +253,7 @@ class NvidiaWithOpenAiChatModelIT { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -278,7 +276,7 @@ class NvidiaWithOpenAiChatModelIT { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -296,7 +294,7 @@ class NvidiaWithOpenAiChatModelIT { @Test void validateCallResponseMetadata() { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(OpenAiChatOptions.builder().withModel(DEFAULT_NVIDIA_MODEL).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() @@ -311,6 +309,10 @@ class NvidiaWithOpenAiChatModelIT { assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration static class Config { @@ -327,4 +329,4 @@ class NvidiaWithOpenAiChatModelIT { } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java index a723f9dae..523a3fa64 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat.proxy; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.chat.proxy; import java.io.IOException; import java.net.URL; @@ -32,6 +31,11 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.ollama.OllamaContainer; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -61,11 +65,8 @@ import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.ollama.OllamaContainer; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @Disabled("For manual smoke testing only.") @Testcontainers @@ -81,6 +82,12 @@ class OllamaWithOpenAiChatModelIT { static String baseUrl = "http://localhost:11434"; + @Value("classpath:/prompts/system-message.st") + private Resource systemResource; + + @Autowired + private OpenAiChatModel chatModel; + @BeforeAll public static void beforeAll() throws IOException, InterruptedException { logger.info("Start pulling the '" + DEFAULT_OLLAMA_MODEL + " ' generative ... would take several minutes ..."); @@ -92,20 +99,14 @@ class OllamaWithOpenAiChatModelIT { baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434); } - @Value("classpath:/prompts/system-message.st") - private Resource systemResource; - - @Autowired - private OpenAiChatModel chatModel; - @Test void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); } @@ -114,10 +115,10 @@ class OllamaWithOpenAiChatModelIT { void streamRoleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - Flux flux = chatModel.stream(prompt); + Flux flux = this.chatModel.stream(prompt); List responses = flux.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -184,7 +185,7 @@ class OllamaWithOpenAiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -203,15 +204,12 @@ class OllamaWithOpenAiChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.getActor()).isNotEmpty(); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -224,7 +222,7 @@ class OllamaWithOpenAiChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -245,7 +243,7 @@ class OllamaWithOpenAiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -278,7 +276,7 @@ class OllamaWithOpenAiChatModelIT { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -302,7 +300,7 @@ class OllamaWithOpenAiChatModelIT { .build())) .build(); - Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() .block() @@ -326,7 +324,7 @@ class OllamaWithOpenAiChatModelIT { var userMessage = new UserMessage("Explain what do you see on this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel + var response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -343,7 +341,7 @@ class OllamaWithOpenAiChatModelIT { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - ChatResponse response = chatModel + ChatResponse response = this.chatModel .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -360,7 +358,7 @@ class OllamaWithOpenAiChatModelIT { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - Flux response = chatModel + Flux response = this.chatModel .stream(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); String content = response.collectList() @@ -380,7 +378,7 @@ class OllamaWithOpenAiChatModelIT { @ValueSource(strings = { "mistral" }) void validateCallResponseMetadata(String model) { // @formatter:off - ChatResponse response = ChatClient.create(chatModel).prompt() + ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(OpenAiChatOptions.builder().withModel(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() @@ -395,6 +393,10 @@ class OllamaWithOpenAiChatModelIT { assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration static class Config { @@ -410,4 +412,4 @@ class OllamaWithOpenAiChatModelIT { } -} \ No newline at end of file +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java index 079990e1a..ae3f4dbb0 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.embedding; -import org.junit.jupiter.api.Test; +import java.nio.charset.StandardCharsets; +import java.util.List; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.document.Document; @@ -33,9 +36,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.core.io.Resource; -import java.nio.charset.StandardCharsets; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -50,9 +50,9 @@ class EmbeddingIT extends AbstractIT { @Test void defaultEmbedding() { - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1536); @@ -60,12 +60,12 @@ class EmbeddingIT extends AbstractIT { assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(2); assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(2); - assertThat(embeddingModel.dimensions()).isEqualTo(1536); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1536); } @Test void embeddingBatchDocuments() throws Exception { - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); List embedded = this.embeddingModel.embed( List.of(new Document("Hello world"), new Document("Hello Spring"), new Document("Hello Spring AI!")), OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(), @@ -76,10 +76,10 @@ class EmbeddingIT extends AbstractIT { @Test void embeddingBatchDocumentsThatExceedTheLimit() throws Exception { - assertThat(embeddingModel).isNotNull(); - String contentAsString = resource.getContentAsString(StandardCharsets.UTF_8); + assertThat(this.embeddingModel).isNotNull(); + String contentAsString = this.resource.getContentAsString(StandardCharsets.UTF_8); assertThatThrownBy(() -> { - embeddingModel.embed(List.of(new Document("Hello World"), new Document(contentAsString)), + this.embeddingModel.embed(List.of(new Document("Hello World"), new Document(contentAsString)), OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(), new TokenCountBatchingStrategy()); }).isInstanceOf(IllegalArgumentException.class); @@ -88,7 +88,7 @@ class EmbeddingIT extends AbstractIT { @Test void embedding3Large() { - EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(List.of("Hello World"), + EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of("Hello World"), OpenAiEmbeddingOptions.builder().withModel("text-embedding-3-large").build())); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); @@ -103,7 +103,7 @@ class EmbeddingIT extends AbstractIT { @Test void textEmbeddingAda002() { - EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(List.of("Hello World"), + EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of("Hello World"), OpenAiEmbeddingOptions.builder().withModel("text-embedding-3-small").build())); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java index 6a3e4d036..f5f0b0462 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.embedding; +import java.util.List; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -35,8 +39,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; @@ -66,13 +68,13 @@ public class OpenAiEmbeddingModelObservationIT { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelIT.java index 6f0ea968a..cd500e2d7 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.image; import org.assertj.core.api.Assertions; @@ -44,7 +45,7 @@ public class OpenAiImageModelIT extends AbstractIT { ImagePrompt imagePrompt = new ImagePrompt(instructions, options); - ImageResponse imageResponse = imageModel.call(imagePrompt); + ImageResponse imageResponse = this.imageModel.call(imagePrompt); assertThat(imageResponse.getResults()).hasSize(1); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java index 0a1d3087d..146504420 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.image; import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; import org.springframework.ai.image.observation.DefaultImageModelObservationConvention; @@ -66,10 +68,10 @@ public class OpenAiImageModelObservationIT { ImagePrompt imagePrompt = new ImagePrompt(instructions, options); - ImageResponse imageResponse = imageModel.call(imagePrompt); + ImageResponse imageResponse = this.imageModel.call(imagePrompt); assertThat(imageResponse.getResults()).hasSize(1); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultImageModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelWithImageResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelWithImageResponseMetadataTests.java index a3b3160f1..47f8b60e5 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelWithImageResponseMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelWithImageResponseMetadataTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.image; +import java.util.List; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; + import org.springframework.ai.image.ImageGeneration; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; @@ -34,12 +38,10 @@ import org.springframework.http.MediaType; import org.springframework.test.web.client.MockRestServiceServer; import org.springframework.web.client.RestClient; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; import static org.springframework.test.web.client.match.MockRestRequestMatchers.header; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; /** @@ -60,7 +62,7 @@ public class OpenAiImageModelWithImageResponseMetadataTests { @AfterEach void resetMockServer() { - server.reset(); + this.server.reset(); } @Test @@ -102,7 +104,7 @@ public class OpenAiImageModelWithImageResponseMetadataTests { httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358"); httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); - server.expect(requestTo("v1/images/generations")) + this.server.expect(requestTo("v1/images/generations")) .andExpect(method(HttpMethod.POST)) .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) .andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders)); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java index 1c7c53e0d..65a97c1c6 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java @@ -1,6 +1,23 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.openai.metadata; import org.junit.jupiter.api.Test; + import org.springframework.ai.openai.api.OpenAiApi; import static org.assertj.core.api.Assertions.assertThat; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractorTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractorTests.java index 050b05a53..dc2aff589 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractorTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.metadata.support; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.openai.metadata.support; import java.time.Duration; @@ -23,6 +22,8 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor.DurationFormatter; +import static org.assertj.core.api.Assertions.assertThat; + /** * Unit Tests for {@link OpenAiHttpResponseHeadersInterceptor}. * diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelIT.java index ed0658862..f00aa7813 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,11 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.moderation; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.moderation.*; + +import org.springframework.ai.moderation.Categories; +import org.springframework.ai.moderation.CategoryScores; +import org.springframework.ai.moderation.Moderation; +import org.springframework.ai.moderation.ModerationOptionsBuilder; +import org.springframework.ai.moderation.ModerationPrompt; +import org.springframework.ai.moderation.ModerationResponse; +import org.springframework.ai.moderation.ModerationResult; import org.springframework.ai.openai.OpenAiTestConfiguration; import org.springframework.ai.openai.testutils.AbstractIT; import org.springframework.boot.test.context.SpringBootTest; @@ -42,7 +50,7 @@ public class OpenAiModerationModelIT extends AbstractIT { ModerationPrompt moderationPrompt = new ModerationPrompt(instructions, options); - ModerationResponse moderationResponse = openAiModerationModel.call(moderationPrompt); + ModerationResponse moderationResponse = this.openAiModerationModel.call(moderationPrompt); assertThat(moderationResponse.getResults()).hasSize(1); @@ -96,7 +104,7 @@ public class OpenAiModerationModelIT extends AbstractIT { ModerationPrompt moderationPrompt = new ModerationPrompt(instructions, options); - ModerationResponse moderationResponse = openAiModerationModel.call(moderationPrompt); + ModerationResponse moderationResponse = this.openAiModerationModel.call(moderationPrompt); assertThat(moderationResponse.getResults()).hasSize(1); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelTests.java index 9d2fd6877..30d2a9c27 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/OpenAiModerationModelTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,18 @@ package org.springframework.ai.openai.moderation; +import java.util.List; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; -import org.springframework.ai.moderation.*; + +import org.springframework.ai.moderation.Categories; +import org.springframework.ai.moderation.CategoryScores; +import org.springframework.ai.moderation.Generation; +import org.springframework.ai.moderation.Moderation; +import org.springframework.ai.moderation.ModerationPrompt; +import org.springframework.ai.moderation.ModerationResponse; +import org.springframework.ai.moderation.ModerationResult; import org.springframework.ai.openai.OpenAiModerationModel; import org.springframework.ai.openai.api.OpenAiModerationApi; import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders; @@ -33,10 +42,10 @@ import org.springframework.http.MediaType; import org.springframework.test.web.client.MockRestServiceServer; import org.springframework.web.client.RestClient; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.header; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; /** @@ -56,7 +65,7 @@ public class OpenAiModerationModelTests { @AfterEach void resetMockServer() { - server.reset(); + this.server.reset(); } @Test @@ -121,7 +130,7 @@ public class OpenAiModerationModelTests { httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358"); httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); - server.expect(requestTo("v1/moderations")) + this.server.expect(requestTo("v1/moderations")) .andExpect(method(HttpMethod.POST)) .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) .andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders)); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java index 944852435..09d7c42f7 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -22,14 +22,13 @@ import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.SystemMessage; - import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.image.ImageModel; import org.springframework.ai.openai.OpenAiAudioSpeechModel; @@ -88,23 +87,23 @@ public abstract class AbstractIT { String answer = response.getResult().getOutput().getContent(); logger.info("Question: " + question); logger.info("Answer:" + answer); - PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource, + PromptTemplate userPromptTemplate = new PromptTemplate(this.userEvaluatorResource, Map.of("question", question, "answer", answer)); SystemMessage systemMessage; if (factBased) { - systemMessage = new SystemMessage(qaEvaluatorFactBasedAnswerResource); + systemMessage = new SystemMessage(this.qaEvaluatorFactBasedAnswerResource); } else { - systemMessage = new SystemMessage(qaEvaluatorAccurateAnswerResource); + systemMessage = new SystemMessage(this.qaEvaluatorAccurateAnswerResource); } Message userMessage = userPromptTemplate.createMessage(); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - String yesOrNo = chatModel.call(prompt).getResult().getOutput().getContent(); + String yesOrNo = this.chatModel.call(prompt).getResult().getOutput().getContent(); logger.info("Is Answer related to question: " + yesOrNo); if (yesOrNo.equalsIgnoreCase("no")) { - SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource); + SystemMessage notRelatedSystemMessage = new SystemMessage(this.qaEvaluatorNotRelatedResource); prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage)); - String reasonForFailure = chatModel.call(prompt).getResult().getOutput().getContent(); + String reasonForFailure = this.chatModel.call(prompt).getResult().getOutput().getContent(); fail(reasonForFailure); } else { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java index bb23600ad..d840cf6a7 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai.transformer; import java.io.IOException; @@ -73,7 +74,7 @@ public class MetadataTransformerIT { @Test public void testKeywordExtractor() { - var updatedDocuments = keywordMetadataEnricher.apply(List.of(document1, document2)); + var updatedDocuments = this.keywordMetadataEnricher.apply(List.of(this.document1, this.document2)); List> keywords = updatedDocuments.stream().map(d -> d.getMetadata()).toList(); @@ -91,7 +92,7 @@ public class MetadataTransformerIT { @Test public void testSummaryExtractor() { - var updatedDocuments = summaryMetadataEnricher.apply(List.of(document1, document2)); + var updatedDocuments = this.summaryMetadataEnricher.apply(List.of(this.document1, this.document2)); List> summaries = updatedDocuments.stream().map(d -> d.getMetadata()).toList(); @@ -115,34 +116,34 @@ public class MetadataTransformerIT { @Test public void testContentFormatEnricher() { - assertThat(((DefaultContentFormatter) document1.getContentFormatter()).getExcludedEmbedMetadataKeys()) + assertThat(((DefaultContentFormatter) this.document1.getContentFormatter()).getExcludedEmbedMetadataKeys()) .doesNotContain("NewEmbedKey"); - assertThat(((DefaultContentFormatter) document1.getContentFormatter()).getExcludedInferenceMetadataKeys()) + assertThat(((DefaultContentFormatter) this.document1.getContentFormatter()).getExcludedInferenceMetadataKeys()) .doesNotContain("NewInferenceKey"); - assertThat(((DefaultContentFormatter) document2.getContentFormatter()).getExcludedEmbedMetadataKeys()) + assertThat(((DefaultContentFormatter) this.document2.getContentFormatter()).getExcludedEmbedMetadataKeys()) .doesNotContain("NewEmbedKey"); - assertThat(((DefaultContentFormatter) document2.getContentFormatter()).getExcludedInferenceMetadataKeys()) + assertThat(((DefaultContentFormatter) this.document2.getContentFormatter()).getExcludedInferenceMetadataKeys()) .doesNotContain("NewInferenceKey"); - List enrichedDocuments = contentFormatTransformer.apply(List.of(document1, document2)); + List enrichedDocuments = this.contentFormatTransformer.apply(List.of(this.document1, this.document2)); assertThat(enrichedDocuments.size()).isEqualTo(2); var doc1 = enrichedDocuments.get(0); var doc2 = enrichedDocuments.get(1); - assertThat(doc1).isEqualTo(document1); - assertThat(doc2).isEqualTo(document2); + assertThat(doc1).isEqualTo(this.document1); + assertThat(doc2).isEqualTo(this.document2); assertThat(((DefaultContentFormatter) doc1.getContentFormatter()).getTextTemplate()) - .isSameAs(defaultContentFormatter.getTextTemplate()); + .isSameAs(this.defaultContentFormatter.getTextTemplate()); assertThat(((DefaultContentFormatter) doc1.getContentFormatter()).getExcludedEmbedMetadataKeys()) .contains("NewEmbedKey"); assertThat(((DefaultContentFormatter) doc1.getContentFormatter()).getExcludedInferenceMetadataKeys()) .contains("NewInferenceKey"); assertThat(((DefaultContentFormatter) doc2.getContentFormatter()).getTextTemplate()) - .isSameAs(defaultContentFormatter.getTextTemplate()); + .isSameAs(this.defaultContentFormatter.getTextTemplate()); assertThat(((DefaultContentFormatter) doc2.getContentFormatter()).getExcludedEmbedMetadataKeys()) .contains("NewEmbedKey"); assertThat(((DefaultContentFormatter) doc2.getContentFormatter()).getExcludedInferenceMetadataKeys()) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java index 21ca5bc49..f58063ccd 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/vectorstore/SimplePersistentVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,49 +13,51 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.vectorstore; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.io.CleanupMode; -import org.junit.jupiter.api.io.TempDir; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.reader.JsonReader; -import org.springframework.ai.vectorstore.SimpleVectorStore; -import org.springframework.ai.reader.JsonMetadataGenerator; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.core.io.Resource; +package org.springframework.ai.openai.vectorstore; import java.io.File; import java.nio.file.Path; import java.util.List; import java.util.Map; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.CleanupMode; +import org.junit.jupiter.api.io.TempDir; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.reader.JsonMetadataGenerator; +import org.springframework.ai.reader.JsonReader; +import org.springframework.ai.vectorstore.SimpleVectorStore; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.io.Resource; + import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest public class SimplePersistentVectorStoreIT { + @TempDir(cleanup = CleanupMode.ON_SUCCESS) + Path workingDir; + @Value("file:src/test/resources/data/acme/bikes.json") private Resource bikesJsonResource; @Autowired private EmbeddingModel embeddingModel; - @TempDir(cleanup = CleanupMode.ON_SUCCESS) - Path workingDir; - @Test void persist() { - JsonReader jsonReader = new JsonReader(bikesJsonResource, new ProductMetadataGenerator(), "price", "name", + JsonReader jsonReader = new JsonReader(this.bikesJsonResource, new ProductMetadataGenerator(), "price", "name", "shortDescription", "description", "tags"); List documents = jsonReader.get(); SimpleVectorStore vectorStore = new SimpleVectorStore(this.embeddingModel); vectorStore.add(documents); - File tempFile = new File(workingDir.toFile(), "temp.txt"); + File tempFile = new File(this.workingDir.toFile(), "temp.txt"); vectorStore.save(tempFile); assertThat(tempFile).isNotEmpty(); assertThat(tempFile).content().contains("Velo 99 XR1 AXS"); diff --git a/models/spring-ai-openai/src/test/resources/application-logging-test.properties b/models/spring-ai-openai/src/test/resources/application-logging-test.properties index 8e8b3b2c3..4466a7180 100644 --- a/models/spring-ai-openai/src/test/resources/application-logging-test.properties +++ b/models/spring-ai-openai/src/test/resources/application-logging-test.properties @@ -1 +1,17 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + logging.level.org.springframework.ai.chat.client.advisor=DEBUG diff --git a/models/spring-ai-postgresml/pom.xml b/models/spring-ai-postgresml/pom.xml index acf8349f4..0312ebd4f 100644 --- a/models/spring-ai-postgresml/pom.xml +++ b/models/spring-ai-postgresml/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModel.java b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModel.java index 80ddb93f2..14ad9d41c 100644 --- a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModel.java +++ b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.postgresml; import java.sql.Array; @@ -54,34 +55,6 @@ public class PostgresMlEmbeddingModel extends AbstractEmbeddingModel implements private final JdbcTemplate jdbcTemplate; - public enum VectorType { - - PG_ARRAY("", null, (rs, i) -> { - Array embedding = rs.getArray("embedding"); - return EmbeddingUtils.toPrimitive((Float[]) embedding.getArray()); - - }), - - PG_VECTOR("::vector", "vector", (rs, i) -> { - String embedding = rs.getString("embedding"); - return EmbeddingUtils.toPrimitive(Arrays.stream((embedding.substring(1, embedding.length() - 1) - /* remove leading '[' and trailing ']' */.split(","))).map(Float::parseFloat).toList()); - }); - - private final String cast; - - private final String extensionName; - - private final RowMapper rowMapper; - - VectorType(String cast, String extensionName, RowMapper rowMapper) { - this.cast = cast; - this.extensionName = extensionName; - this.rowMapper = rowMapper; - } - - } - /** * a constructor * @param jdbcTemplate JdbcTemplate @@ -237,4 +210,32 @@ public class PostgresMlEmbeddingModel extends AbstractEmbeddingModel implements } } + public enum VectorType { + + PG_ARRAY("", null, (rs, i) -> { + Array embedding = rs.getArray("embedding"); + return EmbeddingUtils.toPrimitive((Float[]) embedding.getArray()); + + }), + + PG_VECTOR("::vector", "vector", (rs, i) -> { + String embedding = rs.getString("embedding"); + return EmbeddingUtils.toPrimitive(Arrays.stream((embedding.substring(1, embedding.length() - 1) + /* remove leading '[' and trailing ']' */.split(","))).map(Float::parseFloat).toList()); + }); + + private final String cast; + + private final String extensionName; + + private final RowMapper rowMapper; + + VectorType(String cast, String extensionName, RowMapper rowMapper) { + this.cast = cast; + this.extensionName = extensionName; + this.rowMapper = rowMapper; + } + + } + } diff --git a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptions.java b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptions.java index 1077141b6..265045691 100644 --- a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptions.java +++ b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.postgresml; import java.util.Map; @@ -61,6 +62,50 @@ public class PostgresMlEmbeddingOptions implements EmbeddingOptions { return new Builder(); } + public String getTransformer() { + return this.transformer; + } + + public void setTransformer(String transformer) { + this.transformer = transformer; + } + + public VectorType getVectorType() { + return this.vectorType; + } + + public void setVectorType(VectorType vectorType) { + this.vectorType = vectorType; + } + + public Map getKwargs() { + return this.kwargs; + } + + public void setKwargs(Map kwargs) { + this.kwargs = kwargs; + } + + public MetadataMode getMetadataMode() { + return this.metadataMode; + } + + public void setMetadataMode(MetadataMode metadataMode) { + this.metadataMode = metadataMode; + } + + @Override + @JsonIgnore + public String getModel() { + return null; + } + + @Override + @JsonIgnore + public Integer getDimensions() { + return null; + } + public static class Builder { protected PostgresMlEmbeddingOptions options; @@ -100,48 +145,4 @@ public class PostgresMlEmbeddingOptions implements EmbeddingOptions { } - public String getTransformer() { - return this.transformer; - } - - public void setTransformer(String transformer) { - this.transformer = transformer; - } - - public VectorType getVectorType() { - return this.vectorType; - } - - public void setVectorType(VectorType vectorType) { - this.vectorType = vectorType; - } - - public Map getKwargs() { - return this.kwargs; - } - - public void setKwargs(Map kwargs) { - this.kwargs = kwargs; - } - - public MetadataMode getMetadataMode() { - return metadataMode; - } - - public void setMetadataMode(MetadataMode metadataMode) { - this.metadataMode = metadataMode; - } - - @Override - @JsonIgnore - public String getModel() { - return null; - } - - @Override - @JsonIgnore - public Integer getDimensions() { - return null; - } - } diff --git a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModelIT.java b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModelIT.java index 23697f934..64627bc47 100644 --- a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModelIT.java +++ b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.postgresml; import java.time.Duration; @@ -26,13 +27,6 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; - -import org.springframework.ai.embedding.EmbeddingOptions; -import org.springframework.ai.embedding.EmbeddingRequest; -import org.springframework.ai.embedding.EmbeddingResponse; -import org.springframework.ai.embedding.EmbeddingResponseMetadata; -import org.springframework.ai.postgresml.PostgresMlEmbeddingModel.VectorType; - import org.testcontainers.containers.PostgreSQLContainer; import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy; import org.testcontainers.junit.jupiter.Container; @@ -41,6 +35,11 @@ import org.testcontainers.utility.DockerImageName; import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.postgresml.PostgresMlEmbeddingModel.VectorType; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.test.autoconfigure.jdbc.AutoConfigureTestDatabase; @@ -257,4 +256,4 @@ class PostgresMlEmbeddingModelIT { } -} \ No newline at end of file +} diff --git a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptionsTests.java b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptionsTests.java index 07ce531b7..c0464867c 100644 --- a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptionsTests.java +++ b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptionsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.postgresml; import java.util.Map; diff --git a/models/spring-ai-qianfan/pom.xml b/models/spring-ai-qianfan/pom.xml index 39ea55970..379d29eb2 100644 --- a/models/spring-ai-qianfan/pom.xml +++ b/models/spring-ai-qianfan/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java index 7a9448ef6..aaf68884c 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; +import java.util.Collections; +import java.util.List; +import java.util.Map; + import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.EmptyUsage; @@ -48,12 +56,6 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.util.Collections; -import java.util.List; -import java.util.Map; /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal QianFan} @@ -71,16 +73,16 @@ public class QianFanChatModel implements ChatModel, StreamingChatModel { private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); - /** - * The default options used for the chat completion requests. - */ - private final QianFanChatOptions defaultOptions; - /** * The retry template used to retry the QianFan API calls. */ public final RetryTemplate retryTemplate; + /** + * The default options used for the chat completion requests. + */ + private final QianFanChatOptions defaultOptions; + /** * Low-level access to the QianFan API. */ diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java index 24164ab7f..24bff760b 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.qianfan.api.QianFanApi; import org.springframework.boot.context.properties.NestedConfigurationProperty; -import java.util.List; - /** * QianFanChatOptions represents the options for performing chat completion using the * QianFan API. It provides methods to set and retrieve various options like model, @@ -85,62 +87,17 @@ public class QianFanChatOptions implements ChatOptions { return new Builder(); } - public static class Builder { - - protected QianFanChatOptions options; - - public Builder() { - this.options = new QianFanChatOptions(); - } - - public Builder(QianFanChatOptions options) { - this.options = options; - } - - public Builder withModel(String model) { - this.options.model = model; - return this; - } - - public Builder withFrequencyPenalty(Double frequencyPenalty) { - this.options.frequencyPenalty = frequencyPenalty; - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.options.maxTokens = maxTokens; - return this; - } - - public Builder withPresencePenalty(Double presencePenalty) { - this.options.presencePenalty = presencePenalty; - return this; - } - - public Builder withResponseFormat(QianFanApi.ChatCompletionRequest.ResponseFormat responseFormat) { - this.options.responseFormat = responseFormat; - return this; - } - - public Builder withStop(List stop) { - this.options.stop = stop; - return this; - } - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public QianFanChatOptions build() { - return this.options; - } - + public static QianFanChatOptions fromOptions(QianFanChatOptions fromOptions) { + return QianFanChatOptions.builder() + .withModel(fromOptions.getModel()) + .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withPresencePenalty(fromOptions.getPresencePenalty()) + .withResponseFormat(fromOptions.getResponseFormat()) + .withStop(fromOptions.getStop()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .build(); } @Override @@ -234,74 +191,93 @@ public class QianFanChatOptions implements ChatOptions { public int hashCode() { final int prime = 31; int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode()); - result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); - result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode()); - result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); - result = prime * result + ((stop == null) ? 0 : stop.hashCode()); - result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); - result = prime * result + ((topP == null) ? 0 : topP.hashCode()); + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); + result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); + result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); + result = prime * result + ((this.responseFormat == null) ? 0 : this.responseFormat.hashCode()); + result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); + result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); + result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); return result; } @Override public boolean equals(Object obj) { - if (this == obj) + if (this == obj) { return true; - if (obj == null) + } + if (obj == null) { return false; - if (getClass() != obj.getClass()) + } + if (getClass() != obj.getClass()) { return false; + } QianFanChatOptions other = (QianFanChatOptions) obj; if (this.model == null) { - if (other.model != null) + if (other.model != null) { return false; + } } - else if (!model.equals(other.model)) + else if (!this.model.equals(other.model)) { return false; + } if (this.frequencyPenalty == null) { - if (other.frequencyPenalty != null) + if (other.frequencyPenalty != null) { return false; + } } - else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) + else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) { return false; + } if (this.maxTokens == null) { - if (other.maxTokens != null) + if (other.maxTokens != null) { return false; + } } - else if (!this.maxTokens.equals(other.maxTokens)) + else if (!this.maxTokens.equals(other.maxTokens)) { return false; + } if (this.presencePenalty == null) { - if (other.presencePenalty != null) + if (other.presencePenalty != null) { return false; + } } - else if (!this.presencePenalty.equals(other.presencePenalty)) + else if (!this.presencePenalty.equals(other.presencePenalty)) { return false; + } if (this.responseFormat == null) { - if (other.responseFormat != null) + if (other.responseFormat != null) { return false; + } } - else if (!this.responseFormat.equals(other.responseFormat)) + else if (!this.responseFormat.equals(other.responseFormat)) { return false; + } if (this.stop == null) { - if (other.stop != null) + if (other.stop != null) { return false; + } } - else if (!stop.equals(other.stop)) + else if (!this.stop.equals(other.stop)) { return false; + } if (this.temperature == null) { - if (other.temperature != null) + if (other.temperature != null) { return false; + } } - else if (!this.temperature.equals(other.temperature)) + else if (!this.temperature.equals(other.temperature)) { return false; + } if (this.topP == null) { - if (other.topP != null) + if (other.topP != null) { return false; + } } - else if (!topP.equals(other.topP)) + else if (!this.topP.equals(other.topP)) { return false; + } return true; } @@ -310,17 +286,62 @@ public class QianFanChatOptions implements ChatOptions { return fromOptions(this); } - public static QianFanChatOptions fromOptions(QianFanChatOptions fromOptions) { - return QianFanChatOptions.builder() - .withModel(fromOptions.getModel()) - .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withPresencePenalty(fromOptions.getPresencePenalty()) - .withResponseFormat(fromOptions.getResponseFormat()) - .withStop(fromOptions.getStop()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .build(); + public static class Builder { + + protected QianFanChatOptions options; + + public Builder() { + this.options = new QianFanChatOptions(); + } + + public Builder(QianFanChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withFrequencyPenalty(Double frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withPresencePenalty(Double presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder withResponseFormat(QianFanApi.ChatCompletionRequest.ResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public QianFanChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingModel.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingModel.java index 40323cf62..f681cac1c 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingModel.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; +import java.util.List; + import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -40,8 +44,6 @@ import org.springframework.lang.Nullable; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.List; - /** * QianFan Embedding Client implementation. * diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java index 672b68ab2..60700cff2 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.embedding.EmbeddingOptions; /** @@ -48,6 +50,29 @@ public class QianFanEmbeddingOptions implements EmbeddingOptions { return new Builder(); } + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public String getUser() { + return this.user; + } + + public void setUser(String user) { + this.user = user; + } + + @Override + @JsonIgnore + public Integer getDimensions() { + return null; + } + public static class Builder { protected QianFanEmbeddingOptions options; @@ -72,27 +97,4 @@ public class QianFanEmbeddingOptions implements EmbeddingOptions { } - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - public String getUser() { - return user; - } - - public void setUser(String user) { - this.user = user; - } - - @Override - @JsonIgnore - public Integer getDimensions() { - return null; - } - } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageModel.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageModel.java index de4b7e26f..ba2ca408f 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageModel.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; +import java.util.List; + import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.image.Image; import org.springframework.ai.image.ImageGeneration; import org.springframework.ai.image.ImageModel; @@ -37,8 +41,6 @@ import org.springframework.lang.Nullable; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.List; - /** * QianFanImageModel is a class that implements the ImageModel interface. It provides a * client for calling the QianFan image generation API. diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java index 7ddbd7013..d102d34fe 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; +import java.util.Objects; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.image.ImageOptions; -import java.util.Objects; +import org.springframework.ai.image.ImageOptions; /** * QianFan Image API options. QianFanImageOptions.java @@ -88,50 +90,6 @@ public class QianFanImageOptions implements ImageOptions { return new Builder(); } - public static class Builder { - - private final QianFanImageOptions options; - - private Builder() { - this.options = new QianFanImageOptions(); - } - - public Builder withN(Integer n) { - options.setN(n); - return this; - } - - public Builder withModel(String model) { - options.setModel(model); - return this; - } - - public Builder withWidth(Integer width) { - options.setWidth(width); - return this; - } - - public Builder withHeight(Integer height) { - options.setHeight(height); - return this; - } - - public Builder withStyle(String style) { - options.setStyle(style); - return this; - } - - public Builder withUser(String user) { - options.setUser(user); - return this; - } - - public QianFanImageOptions build() { - return options; - } - - } - @Override public Integer getN() { return this.n; @@ -206,24 +164,72 @@ public class QianFanImageOptions implements ImageOptions { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof QianFanImageOptions that)) + } + if (!(o instanceof QianFanImageOptions that)) { return false; - return Objects.equals(n, that.n) && Objects.equals(model, that.model) && Objects.equals(width, that.width) - && Objects.equals(height, that.height) && Objects.equals(size, that.size) - && Objects.equals(style, that.style) && Objects.equals(user, that.user); + } + return Objects.equals(this.n, that.n) && Objects.equals(this.model, that.model) + && Objects.equals(this.width, that.width) && Objects.equals(this.height, that.height) + && Objects.equals(this.size, that.size) && Objects.equals(this.style, that.style) + && Objects.equals(this.user, that.user); } @Override public int hashCode() { - return Objects.hash(n, model, width, height, size, style, user); + return Objects.hash(this.n, this.model, this.width, this.height, this.size, this.style, this.user); } @Override public String toString() { - return "QianFanImageOptions{" + "n=" + n + ", model='" + model + '\'' + ", width=" + width + ", height=" - + height + ", size='" + size + '\'' + ", style='" + style + '\'' + ", user='" + user + '\'' + '}'; + return "QianFanImageOptions{" + "n=" + this.n + ", model='" + this.model + '\'' + ", width=" + this.width + + ", height=" + this.height + ", size='" + this.size + '\'' + ", style='" + this.style + '\'' + + ", user='" + this.user + '\'' + '}'; + } + + public static class Builder { + + private final QianFanImageOptions options; + + private Builder() { + this.options = new QianFanImageOptions(); + } + + public Builder withN(Integer n) { + this.options.setN(n); + return this; + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withWidth(Integer width) { + this.options.setWidth(width); + return this; + } + + public Builder withHeight(Integer height) { + this.options.setHeight(height); + return this; + } + + public Builder withStyle(String style) { + this.options.setStyle(style); + return this; + } + + public Builder withUser(String user) { + this.options.setUser(user); + return this; + } + + public QianFanImageOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java index a72059161..2538e4f8b 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.aot; import org.springframework.ai.qianfan.api.QianFanApi; diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java index 5a680338f..da93b16b6 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.api; +import java.util.List; +import java.util.function.Predicate; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.qianfan.api.auth.AuthApi; import org.springframework.ai.retry.RetryUtils; import org.springframework.core.ParameterizedTypeReference; @@ -27,11 +34,6 @@ import org.springframework.util.CollectionUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.util.List; -import java.util.function.Predicate; // @formatter:off /** @@ -125,6 +127,70 @@ public class QianFanApi extends AuthApi { .build(); } + /** + * Creates a model response for the given chat conversation. + * + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post() + .uri("/v1/wenxinworkshop/chat/{model}?access_token={token}",chatRequest.model, getAccessToken()) + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + return this.webClient.post() + .uri("/v1/wenxinworkshop/chat/{model}?access_token={token}",chatRequest.model, getAccessToken()) + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(ChatCompletionChunk.class) + .takeUntil(SSE_DONE_PREDICATE); + } + + /** + * Creates an embedding vector representing the input text or token array. + * @param embeddingRequest The embedding request. + * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. + */ + public ResponseEntity embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + + // Input text to embed, encoded as a string or array of tokens. To embed multiple + // inputs in a single + // request, pass an array of strings or array of token arrays. + Assert.notNull(embeddingRequest.texts(), "The input can not be null."); + + // The input must not an empty string, and any array must be 16 dimensions or + // less. + Assert.isTrue(!CollectionUtils.isEmpty(embeddingRequest.texts()), "The input list can not be empty."); + Assert.isTrue(embeddingRequest.texts().size() <= 16, "The list must be 16 dimensions or less"); + + return this.restClient.post() + .uri("/v1/wenxinworkshop/embeddings/{model}?access_token={token}", embeddingRequest.model, getAccessToken()) + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + + }); + } + /** * QianFan Chat Completion Models: * QianFan Model. @@ -157,7 +223,44 @@ public class QianFanApi extends AuthApi { } public String getValue() { - return value; + return this.value; + } + } + + /** + * QianFan Embeddings Models: + * Embeddings. + */ + public enum EmbeddingModel { + + /** + * DIMENSION: 384 + */ + EMBEDDING_V1("embedding-v1"), + + /** + * DIMENSION: 1024 + */ + BGE_LARGE_ZH("bge_large_zh"), + + /** + * DIMENSION: 1024 + */ + BGE_LARGE_EN("bge_large_en"), + + /** + * DIMENSION: 1024 + */ + TAO_8K("tao_8k"); + + public final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; } } @@ -348,79 +451,6 @@ public class QianFanApi extends AuthApi { ) { } - /** - * Creates a model response for the given chat conversation. - * - * @param chatRequest The chat completion request. - * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); - - return this.restClient.post() - .uri("/v1/wenxinworkshop/chat/{model}?access_token={token}",chatRequest.model, getAccessToken()) - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletion.class); - } - - /** - * Creates a streaming chat response for the given chat conversation. - * @param chatRequest The chat completion request. Must have the stream property set - * to true. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); - - return this.webClient.post() - .uri("/v1/wenxinworkshop/chat/{model}?access_token={token}",chatRequest.model, getAccessToken()) - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(ChatCompletionChunk.class) - .takeUntil(SSE_DONE_PREDICATE); - } - - /** - * QianFan Embeddings Models: - * Embeddings. - */ - public enum EmbeddingModel { - - /** - * DIMENSION: 384 - */ - EMBEDDING_V1("embedding-v1"), - - /** - * DIMENSION: 1024 - */ - BGE_LARGE_ZH("bge_large_zh"), - - /** - * DIMENSION: 1024 - */ - BGE_LARGE_EN("bge_large_en"), - - /** - * DIMENSION: 1024 - */ - TAO_8K("tao_8k"); - - public final String value; - - EmbeddingModel(String value) { - this.value = value; - } - - public String getValue() { - return value; - } - } - /** * Creates an embedding vector representing the input text. * @@ -502,6 +532,7 @@ public class QianFanApi extends AuthApi { public Embedding(Integer index, float[] embedding) { this(index, embedding, "embedding"); } + } /** @@ -524,32 +555,5 @@ public class QianFanApi extends AuthApi { // @formatter:on } - /** - * Creates an embedding vector representing the input text or token array. - * @param embeddingRequest The embedding request. - * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. - */ - public ResponseEntity embeddings(EmbeddingRequest embeddingRequest) { - - Assert.notNull(embeddingRequest, "The request body can not be null."); - - // Input text to embed, encoded as a string or array of tokens. To embed multiple - // inputs in a single - // request, pass an array of strings or array of token arrays. - Assert.notNull(embeddingRequest.texts(), "The input can not be null."); - - // The input must not an empty string, and any array must be 16 dimensions or - // less. - Assert.isTrue(!CollectionUtils.isEmpty(embeddingRequest.texts()), "The input list can not be empty."); - Assert.isTrue(embeddingRequest.texts().size() <= 16, "The list must be 16 dimensions or less"); - - return this.restClient.post() - .uri("/v1/wenxinworkshop/embeddings/{model}?access_token={token}", embeddingRequest.model, getAccessToken()) - .body(embeddingRequest) - .retrieve() - .toEntity(new ParameterizedTypeReference<>() { - }); - } - } // @formatter:on diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanConstants.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanConstants.java index b269500a4..5dd2744f7 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanConstants.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanConstants.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.api; import org.springframework.ai.observation.conventions.AiProvider; diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanImageApi.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanImageApi.java index 2fb20942f..2532e52df 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanImageApi.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanImageApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.api; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.qianfan.api.auth.AuthApi; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; @@ -24,8 +28,6 @@ import org.springframework.util.Assert; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import java.util.List; - /** * QianFan Image API. * @@ -76,6 +78,18 @@ public class QianFanImageApi extends AuthApi { .build(); } + public ResponseEntity createImage(QianFanImageRequest qianFanImageRequest) { + Assert.notNull(qianFanImageRequest, "Image request cannot be null."); + Assert.hasLength(qianFanImageRequest.prompt(), "Prompt cannot be empty."); + + return this.restClient.post() + .uri("/v1/wenxinworkshop/text2image/{model}?access_token={token}", qianFanImageRequest.model(), + getAccessToken()) + .body(qianFanImageRequest) + .retrieve() + .toEntity(QianFanImageResponse.class); + } + /** * QianFan Image API model. */ @@ -122,24 +136,11 @@ public class QianFanImageApi extends AuthApi { @JsonProperty("created") Long created, @JsonProperty("data") List data) { } - - @JsonInclude(JsonInclude.Include.NON_NULL) - public record Data( - @JsonProperty("index") Integer index, - @JsonProperty("b64_image") String b64Image) { - } // @formatter:onn - public ResponseEntity createImage(QianFanImageRequest qianFanImageRequest) { - Assert.notNull(qianFanImageRequest, "Image request cannot be null."); - Assert.hasLength(qianFanImageRequest.prompt(), "Prompt cannot be empty."); + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Data(@JsonProperty("index") Integer index, @JsonProperty("b64_image") String b64Image) { - return this.restClient.post() - .uri("/v1/wenxinworkshop/text2image/{model}?access_token={token}", qianFanImageRequest.model(), - getAccessToken()) - .body(qianFanImageRequest) - .retrieve() - .toEntity(QianFanImageResponse.class); } } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanUtils.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanUtils.java index f8668c97f..fb1e9723b 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanUtils.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanUtils.java @@ -1,10 +1,26 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.qianfan.api; +import java.util.function.Consumer; + import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; -import java.util.function.Consumer; - public class QianFanUtils { public static Consumer defaultHeaders() { diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AccessTokenResponse.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AccessTokenResponse.java index 8681343af..96070de8c 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AccessTokenResponse.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AccessTokenResponse.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.qianfan.api.auth; import com.fasterxml.jackson.annotation.JsonProperty; @@ -12,4 +28,5 @@ public record AccessTokenResponse(@JsonProperty("access_token") String accessTok @JsonProperty("refresh_token") String refreshToken, @JsonProperty("expires_in") Long expiresIn, @JsonProperty("session_key") String sessionKey, @JsonProperty("session_secret") String sessionSecret, @JsonProperty("error") String error, @JsonProperty("error_description") String errorDescription, String scope) { + } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AuthApi.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AuthApi.java index 648e61fd3..b265c8ce8 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AuthApi.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AuthApi.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.qianfan.api.auth; /** diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java index 0c3e20f2c..ec29676eb 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.qianfan.api.auth; /** @@ -31,35 +47,35 @@ public class QianFanAccessToken { this.sessionKey = accessTokenResponse.sessionKey(); this.sessionSecret = accessTokenResponse.sessionSecret(); this.scope = accessTokenResponse.scope(); - this.refreshTime = getCurrentTimeInSeconds() + (long) ((double) expiresIn * FRACTION_OF_TIME_TO_LIVE); + this.refreshTime = getCurrentTimeInSeconds() + (long) ((double) this.expiresIn * FRACTION_OF_TIME_TO_LIVE); } public String getAccessToken() { - return accessToken; + return this.accessToken; } public String getRefreshToken() { - return refreshToken; + return this.refreshToken; } public Long getExpiresIn() { - return expiresIn; + return this.expiresIn; } public String getSessionKey() { - return sessionKey; + return this.sessionKey; } public String getSessionSecret() { - return sessionSecret; + return this.sessionSecret; } public Long getRefreshTime() { - return refreshTime; + return this.refreshTime; } public String getScope() { - return scope; + return this.scope; } public synchronized boolean needsRefresh() { diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAuthenticator.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAuthenticator.java index b9af294ff..d92ac2b59 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAuthenticator.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAuthenticator.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.qianfan.api.auth; import org.springframework.http.ResponseEntity; @@ -28,9 +44,13 @@ public class QianFanAuthenticator { this.restClient = RestClient.builder().baseUrl(authUrl).build(); } + public static Builder builder() { + return new Builder(); + } + public QianFanAccessToken requestToken() { ResponseEntity tokenResponseEntity = this.restClient.get() - .uri(OPERATION_PATH, apiKey, secretKey) + .uri(OPERATION_PATH, this.apiKey, this.secretKey) .retrieve() .toEntity(AccessTokenResponse.class); AccessTokenResponse tokenResponse = tokenResponseEntity.getBody(); @@ -63,13 +83,9 @@ public class QianFanAuthenticator { } public QianFanAuthenticator build() { - return new QianFanAuthenticator(DEFAULT_AUTH_URL, apiKey, secretKey); + return new QianFanAuthenticator(DEFAULT_AUTH_URL, this.apiKey, this.secretKey); } } - public static Builder builder() { - return new Builder(); - } - } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/metadata/QianFanUsage.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/metadata/QianFanUsage.java index 6b5921ec9..eaa69e755 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/metadata/QianFanUsage.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/metadata/QianFanUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.metadata; import org.springframework.ai.chat.metadata.Usage; @@ -26,10 +27,6 @@ import org.springframework.util.Assert; */ public class QianFanUsage implements Usage { - public static QianFanUsage from(QianFanApi.Usage usage) { - return new QianFanUsage(usage); - } - private final QianFanApi.Usage usage; protected QianFanUsage(QianFanApi.Usage usage) { @@ -37,6 +34,10 @@ public class QianFanUsage implements Usage { this.usage = usage; } + public static QianFanUsage from(QianFanApi.Usage usage) { + return new QianFanUsage(usage); + } + protected QianFanApi.Usage getUsage() { return this.usage; } diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/ChatCompletionRequestTests.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/ChatCompletionRequestTests.java index c4f76a402..de3bf0400 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/ChatCompletionRequestTests.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/ChatCompletionRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.qianfan.api.QianFanApi; diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanTestConfiguration.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanTestConfiguration.java index be98b6819..d1c84192d 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanTestConfiguration.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan; import org.springframework.ai.embedding.EmbeddingModel; diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java index 38f34b72c..f8dae1f20 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.api; +import java.util.List; +import java.util.Objects; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import org.stringtemplate.v4.ST; +import reactor.core.publisher.Flux; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletion; import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionChunk; @@ -26,11 +33,6 @@ import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionMessage.Role; import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionRequest; import org.springframework.ai.qianfan.api.QianFanApi.EmbeddingList; import org.springframework.http.ResponseEntity; -import org.stringtemplate.v4.ST; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Objects; import static org.assertj.core.api.Assertions.assertThat; @@ -46,7 +48,7 @@ public class QianFanApiIT { @Test void chatCompletionEntity() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - ResponseEntity response = qianFanApi.chatCompletionEntity(new ChatCompletionRequest( + ResponseEntity response = this.qianFanApi.chatCompletionEntity(new ChatCompletionRequest( List.of(chatCompletionMessage), buildSystemMessage(), "ernie_speed", 0.7, false)); assertThat(response).isNotNull(); @@ -56,7 +58,7 @@ public class QianFanApiIT { @Test void chatCompletionStream() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - Flux response = qianFanApi.chatCompletionStream(new ChatCompletionRequest( + Flux response = this.qianFanApi.chatCompletionStream(new ChatCompletionRequest( List.of(chatCompletionMessage), buildSystemMessage(), "ernie_speed", 0.7, true)); assertThat(response).isNotNull(); @@ -65,7 +67,8 @@ public class QianFanApiIT { @Test void embeddings() { - ResponseEntity response = qianFanApi.embeddings(new QianFanApi.EmbeddingRequest("Hello world")); + ResponseEntity response = this.qianFanApi + .embeddings(new QianFanApi.EmbeddingRequest("Hello world")); assertThat(response).isNotNull(); assertThat(Objects.requireNonNull(response.getBody()).data()).hasSize(1); diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java index e59baeccd..978eb7216 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.api; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.image.ImageMessage; @@ -47,11 +54,6 @@ import org.springframework.retry.RetryCallback; import org.springframework.retry.RetryContext; import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Objects; -import java.util.Optional; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -64,25 +66,6 @@ import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) public class QianFanRetryTests { - private static class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private @Mock QianFanApi qianFanApi; @@ -98,13 +81,14 @@ public class QianFanRetryTests { @BeforeEach public void beforeEach() { RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); + this.retryListener = new TestRetryListener(); + retryTemplate.registerListener(this.retryListener); - chatClient = new QianFanChatModel(qianFanApi, QianFanChatOptions.builder().build(), retryTemplate); - embeddingClient = new QianFanEmbeddingModel(qianFanApi, MetadataMode.EMBED, + this.chatClient = new QianFanChatModel(this.qianFanApi, QianFanChatOptions.builder().build(), retryTemplate); + this.embeddingClient = new QianFanEmbeddingModel(this.qianFanApi, MetadataMode.EMBED, QianFanEmbeddingOptions.builder().build(), retryTemplate); - imageModel = new QianFanImageModel(qianFanImageApi, QianFanImageOptions.builder().build(), retryTemplate); + this.imageModel = new QianFanImageModel(this.qianFanImageApi, QianFanImageOptions.builder().build(), + retryTemplate); } @Test @@ -112,24 +96,24 @@ public class QianFanRetryTests { ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 666L, "Response", "STOP", new Usage(10, 10, 10)); - when(qianFanApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.qianFanApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); - var result = chatClient.call(new Prompt("text")); + var result = this.chatClient.call(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void qianFanChatNonTransientError() { - when(qianFanApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.qianFanApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatClient.call(new Prompt("text"))); + assertThrows(RuntimeException.class, () -> this.chatClient.call(new Prompt("text"))); } @Test @@ -138,25 +122,25 @@ public class QianFanRetryTests { ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion", 666L, "Response", "", true, null); - when(qianFanApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.qianFanApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(Flux.just(expectedChatCompletion)); - var result = chatClient.stream(new Prompt("text")); + var result = this.chatClient.stream(new Prompt("text")); assertThat(result).isNotNull(); assertThat(Objects.requireNonNull(result.collectList().block()).get(0).getResult().getOutput().getContent()) .isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void qianFanChatStreamNonTransientError() { - when(qianFanApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.qianFanApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatClient.stream(new Prompt("text")).collectList().block()); + assertThrows(RuntimeException.class, () -> this.chatClient.stream(new Prompt("text")).collectList().block()); } @Test @@ -165,24 +149,25 @@ public class QianFanRetryTests { EmbeddingList expectedEmbeddings = new EmbeddingList("embedding_list", List.of(embedding), "model", null, null, new Usage(10, 10, 10)); - when(qianFanApi.embeddings(isA(EmbeddingRequest.class))) + when(this.qianFanApi.embeddings(isA(EmbeddingRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); - var result = embeddingClient + var result = this.embeddingClient .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void qianFanEmbeddingNonTransientError() { - when(qianFanApi.embeddings(isA(EmbeddingRequest.class))).thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> embeddingClient + when(this.qianFanApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> this.embeddingClient .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); } @@ -191,25 +176,44 @@ public class QianFanRetryTests { var expectedResponse = new QianFanImageResponse("1", 678L, List.of(new Data(1, "b64"))); - when(qianFanImageApi.createImage(isA(QianFanImageRequest.class))) + when(this.qianFanImageApi.createImage(isA(QianFanImageRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedResponse))); - var result = imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); + var result = this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getB64Json()).isEqualTo("b64"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void qianFanImageNonTransientError() { - when(qianFanImageApi.createImage(isA(QianFanImageRequest.class))) + when(this.qianFanImageApi.createImage(isA(QianFanImageRequest.class))) .thenThrow(new RuntimeException("Transient Error 1")); assertThrows(RuntimeException.class, - () -> imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); + () -> this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + } } diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelIT.java index 1ed3d9d1a..46c4c67e5 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.chat; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -32,11 +39,6 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.io.Resource; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -61,10 +63,10 @@ class QianFanChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about three famous pirates from the Golden Age of Piracy in english, focusing on their original nicknames and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); } @@ -73,10 +75,10 @@ class QianFanChatModelIT { void streamRoleTest() { UserMessage userMessage = new UserMessage( "Tell me about three famous pirates from the Golden Age of Piracy in english, focusing on their original nicknames and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - Flux flux = streamingChatModel.stream(prompt); + Flux flux = this.streamingChatModel.stream(prompt); List responses = flux.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -91,4 +93,4 @@ class QianFanChatModelIT { assertThat(stitchedResponseContent).contains("Blackbeard"); } -} \ No newline at end of file +} diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java index cd70b8365..4b447ffa3 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.chat; +import java.util.List; +import java.util.stream.Collectors; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; @@ -35,10 +41,6 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -62,7 +64,7 @@ public class QianFanChatModelObservationIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -80,7 +82,7 @@ public class QianFanChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -103,7 +105,7 @@ public class QianFanChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -123,7 +125,7 @@ public class QianFanChatModelObservationIT { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/EmbeddingIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/EmbeddingIT.java index 39a53d421..371d5c2a8 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/EmbeddingIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/EmbeddingIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,21 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.embedding; +import java.util.List; + import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; + import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.qianfan.QianFanTestConfiguration; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -43,23 +44,23 @@ class EmbeddingIT { @Test void defaultEmbedding() { - Assertions.assertThat(embeddingModel).isNotNull(); + Assertions.assertThat(this.embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024); - Assertions.assertThat(embeddingModel.dimensions()).isEqualTo(1024); + Assertions.assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } @Test void batchEmbedding() { - Assertions.assertThat(embeddingModel).isNotNull(); + Assertions.assertThat(this.embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World", "HI")); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World", "HI")); assertThat(embeddingResponse.getResults()).hasSize(2); @@ -69,7 +70,7 @@ class EmbeddingIT { assertThat(embeddingResponse.getResults().get(1)).isNotNull(); assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(1024); - Assertions.assertThat(embeddingModel.dimensions()).isEqualTo(1024); + Assertions.assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } } diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/QianFanEmbeddingModelObservationIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/QianFanEmbeddingModelObservationIT.java index c143a63ec..5061626a7 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/QianFanEmbeddingModelObservationIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/QianFanEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.embedding; +import java.util.List; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -36,8 +40,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; @@ -66,13 +68,13 @@ public class QianFanEmbeddingModelObservationIT { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java index 32ea9a439..6e4be44c7 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.image; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; + import org.springframework.ai.image.Image; import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageOptionsBuilder; @@ -50,7 +52,7 @@ public class QianFanImageModelIT { ImagePrompt imagePrompt = new ImagePrompt(instructions, options); - ImageResponse imageResponse = imageModel.call(imagePrompt); + ImageResponse imageResponse = this.imageModel.call(imagePrompt); assertThat(imageResponse.getResults()).hasSize(1); diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelObservationIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelObservationIT.java index 6dbab1452..3ddaf41e5 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelObservationIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.qianfan.image; import io.micrometer.observation.tck.TestObservationRegistry; @@ -20,6 +21,7 @@ import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; + import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; import org.springframework.ai.image.observation.DefaultImageModelObservationConvention; @@ -67,10 +69,10 @@ public class QianFanImageModelObservationIT { ImagePrompt imagePrompt = new ImagePrompt(instructions, options); - ImageResponse imageResponse = imageModel.call(imagePrompt); + ImageResponse imageResponse = this.imageModel.call(imagePrompt); assertThat(imageResponse.getResults()).hasSize(1); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultImageModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-stability-ai/pom.xml b/models/spring-ai-stability-ai/pom.xml index d45acec7d..d60bc4433 100644 --- a/models/spring-ai-stability-ai/pom.xml +++ b/models/spring-ai-stability-ai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java index 648dae1a5..7626b1792 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.stabilityai; -import org.springframework.ai.image.ImageGenerationMetadata; - import java.util.Objects; +import org.springframework.ai.image.ImageGenerationMetadata; + /** * Represents metadata associated with the image generation process in the StabilityAI * framework. @@ -50,10 +51,12 @@ public class StabilityAiImageGenerationMetadata implements ImageGenerationMetada @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof StabilityAiImageGenerationMetadata that)) + } + if (!(o instanceof StabilityAiImageGenerationMetadata that)) { return false; + } return Objects.equals(this.finishReason, that.finishReason) && Objects.equals(this.seed, that.seed); } diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java index e1db5e2ac..35a980593 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.stabilityai; import java.util.List; import java.util.stream.Collectors; import org.springframework.ai.image.Image; -import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageGeneration; +import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageOptions; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; @@ -51,6 +52,26 @@ public class StabilityAiImageModel implements ImageModel { this.defaultOptions = defaultOptions; } + private static StabilityAiApi.GenerateImageRequest getGenerateImageRequest(ImagePrompt stabilityAiImagePrompt, + StabilityAiImageOptions optionsToUse) { + return new StabilityAiApi.GenerateImageRequest.Builder() + .withTextPrompts(stabilityAiImagePrompt.getInstructions() + .stream() + .map(message -> new StabilityAiApi.GenerateImageRequest.TextPrompts(message.getText(), + message.getWeight())) + .collect(Collectors.toList())) + .withHeight(optionsToUse.getHeight()) + .withWidth(optionsToUse.getWidth()) + .withCfgScale(optionsToUse.getCfgScale()) + .withClipGuidancePreset(optionsToUse.getClipGuidancePreset()) + .withSampler(optionsToUse.getSampler()) + .withSamples(optionsToUse.getN()) + .withSeed(optionsToUse.getSeed()) + .withSteps(optionsToUse.getSteps()) + .withStylePreset(optionsToUse.getStylePreset()) + .build(); + } + public StabilityAiImageOptions getOptions() { return this.defaultOptions; } @@ -82,26 +103,6 @@ public class StabilityAiImageModel implements ImageModel { return convertResponse(generateImageResponse); } - private static StabilityAiApi.GenerateImageRequest getGenerateImageRequest(ImagePrompt stabilityAiImagePrompt, - StabilityAiImageOptions optionsToUse) { - return new StabilityAiApi.GenerateImageRequest.Builder() - .withTextPrompts(stabilityAiImagePrompt.getInstructions() - .stream() - .map(message -> new StabilityAiApi.GenerateImageRequest.TextPrompts(message.getText(), - message.getWeight())) - .collect(Collectors.toList())) - .withHeight(optionsToUse.getHeight()) - .withWidth(optionsToUse.getWidth()) - .withCfgScale(optionsToUse.getCfgScale()) - .withClipGuidancePreset(optionsToUse.getClipGuidancePreset()) - .withSampler(optionsToUse.getSampler()) - .withSamples(optionsToUse.getN()) - .withSeed(optionsToUse.getSeed()) - .withSteps(optionsToUse.getSteps()) - .withStylePreset(optionsToUse.getStylePreset()) - .build(); - } - private ImageResponse convertResponse(StabilityAiApi.GenerateImageResponse generateImageResponse) { List imageGenerationList = generateImageResponse.artifacts() .stream() diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java index e1d7c9efa..f3d76b3fa 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.stabilityai; /** @@ -48,7 +49,7 @@ public enum StyleEnum { @Override public String toString() { - return text; + return this.text; } } \ No newline at end of file diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java index 2ee5b2f7f..5b3b7f546 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.stabilityai.api; import java.util.List; @@ -73,8 +74,8 @@ public class StabilityAiApi { Consumer jsonContentHeaders = headers -> { headers.setBearerAuth(apiKey); headers.setAccept(List.of(MediaType.APPLICATION_JSON)); // base64 in JSON + - // metadata or return - // image in bytes. + // metadata or return + // image in bytes. headers.setContentType(MediaType.APPLICATION_JSON); }; @@ -84,6 +85,15 @@ public class StabilityAiApi { .build(); } + public GenerateImageResponse generateImage(GenerateImageRequest request) { + Assert.notNull(request, "The request body can not be null."); + return this.restClient.post() + .uri("/generation/{model}/text-to-image", this.model) + .body(request) + .retrieve() + .body(GenerateImageResponse.class); + } + @JsonInclude(JsonInclude.Include.NON_NULL) public record GenerateImageRequest(@JsonProperty("text_prompts") List textPrompts, @JsonProperty("height") Integer height, @JsonProperty("width") Integer width, @@ -92,15 +102,15 @@ public class StabilityAiApi { @JsonProperty("seed") Long seed, @JsonProperty("steps") Integer steps, @JsonProperty("style_present") String stylePreset) { + public static Builder builder() { + return new Builder(); + } + @JsonInclude(JsonInclude.Include.NON_NULL) public record TextPrompts(@JsonProperty("text") String text, @JsonProperty("weight") Float weight) { } - public static Builder builder() { - return new Builder(); - } - public static class Builder { List textPrompts; @@ -178,28 +188,23 @@ public class StabilityAiApi { } public GenerateImageRequest build() { - return new GenerateImageRequest(textPrompts, height, width, cfgScale, clipGuidancePreset, sampler, - samples, seed, steps, stylePreset); + return new GenerateImageRequest(this.textPrompts, this.height, this.width, this.cfgScale, + this.clipGuidancePreset, this.sampler, this.samples, this.seed, this.steps, this.stylePreset); } } + } @JsonInclude(JsonInclude.Include.NON_NULL) public record GenerateImageResponse(@JsonProperty("result") String result, @JsonProperty("artifacts") List artifacts) { + public record Artifacts(@JsonProperty("seed") long seed, @JsonProperty("base64") String base64, @JsonProperty("finishReason") String finishReason) { - } - } - public GenerateImageResponse generateImage(GenerateImageRequest request) { - Assert.notNull(request, "The request body can not be null."); - return this.restClient.post() - .uri("/generation/{model}/text-to-image", this.model) - .body(request) - .retrieve() - .body(GenerateImageResponse.class); + } + } } diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java index 4bf839e36..645e13f1a 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.stabilityai.api; +import java.util.Objects; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.image.ImageOptions; import org.springframework.ai.stabilityai.StyleEnum; -import java.util.Objects; - /** * StabilityAiImageOptions is an interface that extends ImageOptions. It provides * additional stability AI specific image options. @@ -288,88 +290,9 @@ public class StabilityAiImageOptions implements ImageOptions { return new Builder(); } - public static class Builder { - - private final StabilityAiImageOptions options; - - private Builder() { - this.options = new StabilityAiImageOptions(); - } - - public Builder withN(Integer n) { - options.setN(n); - return this; - } - - public Builder withModel(String model) { - options.setModel(model); - return this; - } - - public Builder withWidth(Integer width) { - options.setWidth(width); - return this; - } - - public Builder withHeight(Integer height) { - options.setHeight(height); - return this; - } - - public Builder withResponseFormat(String responseFormat) { - options.setResponseFormat(responseFormat); - return this; - } - - public Builder withCfgScale(Float cfgScale) { - options.setCfgScale(cfgScale); - return this; - } - - public Builder withClipGuidancePreset(String clipGuidancePreset) { - options.setClipGuidancePreset(clipGuidancePreset); - return this; - } - - public Builder withSampler(String sampler) { - options.setSampler(sampler); - return this; - } - - public Builder withSeed(Long seed) { - options.setSeed(seed); - return this; - } - - public Builder withSteps(Integer steps) { - options.setSteps(steps); - return this; - } - - public Builder withSamples(Integer samples) { - options.setN(samples); - return this; - } - - public Builder withStylePreset(String stylePreset) { - options.setStylePreset(stylePreset); - return this; - } - - public Builder withStylePreset(StyleEnum styleEnum) { - options.setStylePreset(styleEnum.toString()); - return this; - } - - public StabilityAiImageOptions build() { - return options; - } - - } - @Override public Integer getN() { - return n; + return this.n; } public void setN(Integer n) { @@ -378,7 +301,7 @@ public class StabilityAiImageOptions implements ImageOptions { @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -387,7 +310,7 @@ public class StabilityAiImageOptions implements ImageOptions { @Override public Integer getWidth() { - return width; + return this.width; } public void setWidth(Integer width) { @@ -396,7 +319,7 @@ public class StabilityAiImageOptions implements ImageOptions { @Override public Integer getHeight() { - return height; + return this.height; } public void setHeight(Integer height) { @@ -405,7 +328,7 @@ public class StabilityAiImageOptions implements ImageOptions { @Override public String getResponseFormat() { - return responseFormat; + return this.responseFormat; } public void setResponseFormat(String responseFormat) { @@ -413,7 +336,7 @@ public class StabilityAiImageOptions implements ImageOptions { } public Float getCfgScale() { - return cfgScale; + return this.cfgScale; } public void setCfgScale(Float cfgScale) { @@ -421,7 +344,7 @@ public class StabilityAiImageOptions implements ImageOptions { } public String getClipGuidancePreset() { - return clipGuidancePreset; + return this.clipGuidancePreset; } public void setClipGuidancePreset(String clipGuidancePreset) { @@ -429,7 +352,7 @@ public class StabilityAiImageOptions implements ImageOptions { } public String getSampler() { - return sampler; + return this.sampler; } public void setSampler(String sampler) { @@ -437,7 +360,7 @@ public class StabilityAiImageOptions implements ImageOptions { } public Long getSeed() { - return seed; + return this.seed; } public void setSeed(Long seed) { @@ -445,7 +368,7 @@ public class StabilityAiImageOptions implements ImageOptions { } public Integer getSteps() { - return steps; + return this.steps; } public void setSteps(Integer steps) { @@ -464,7 +387,7 @@ public class StabilityAiImageOptions implements ImageOptions { } public String getStylePreset() { - return stylePreset; + return this.stylePreset; } public void setStylePreset(String stylePreset) { @@ -473,30 +396,113 @@ public class StabilityAiImageOptions implements ImageOptions { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof StabilityAiImageOptions that)) + } + if (!(o instanceof StabilityAiImageOptions that)) { return false; - return Objects.equals(n, that.n) && Objects.equals(model, that.model) && Objects.equals(width, that.width) - && Objects.equals(height, that.height) && Objects.equals(responseFormat, that.responseFormat) - && Objects.equals(cfgScale, that.cfgScale) - && Objects.equals(clipGuidancePreset, that.clipGuidancePreset) && Objects.equals(sampler, that.sampler) - && Objects.equals(seed, that.seed) && Objects.equals(steps, that.steps) - && Objects.equals(stylePreset, that.stylePreset); + } + return Objects.equals(this.n, that.n) && Objects.equals(this.model, that.model) + && Objects.equals(this.width, that.width) && Objects.equals(this.height, that.height) + && Objects.equals(this.responseFormat, that.responseFormat) + && Objects.equals(this.cfgScale, that.cfgScale) + && Objects.equals(this.clipGuidancePreset, that.clipGuidancePreset) + && Objects.equals(this.sampler, that.sampler) && Objects.equals(this.seed, that.seed) + && Objects.equals(this.steps, that.steps) && Objects.equals(this.stylePreset, that.stylePreset); } @Override public int hashCode() { - return Objects.hash(n, model, width, height, responseFormat, cfgScale, clipGuidancePreset, sampler, seed, steps, - stylePreset); + return Objects.hash(this.n, this.model, this.width, this.height, this.responseFormat, this.cfgScale, + this.clipGuidancePreset, this.sampler, this.seed, this.steps, this.stylePreset); } @Override public String toString() { - return "StabilityAiImageOptions{" + "n=" + n + ", model='" + model + '\'' + ", width=" + width + ", height=" - + height + ", responseFormat='" + responseFormat + '\'' + ", cfgScale=" + cfgScale - + ", clipGuidancePreset='" + clipGuidancePreset + '\'' + ", sampler='" + sampler + '\'' + ", seed=" - + seed + ", steps=" + steps + ", stylePreset='" + stylePreset + '\'' + '}'; + return "StabilityAiImageOptions{" + "n=" + this.n + ", model='" + this.model + '\'' + ", width=" + this.width + + ", height=" + this.height + ", responseFormat='" + this.responseFormat + '\'' + ", cfgScale=" + + this.cfgScale + ", clipGuidancePreset='" + this.clipGuidancePreset + '\'' + ", sampler='" + + this.sampler + '\'' + ", seed=" + this.seed + ", steps=" + this.steps + ", stylePreset='" + + this.stylePreset + '\'' + '}'; + } + + public static class Builder { + + private final StabilityAiImageOptions options; + + private Builder() { + this.options = new StabilityAiImageOptions(); + } + + public Builder withN(Integer n) { + this.options.setN(n); + return this; + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withWidth(Integer width) { + this.options.setWidth(width); + return this; + } + + public Builder withHeight(Integer height) { + this.options.setHeight(height); + return this; + } + + public Builder withResponseFormat(String responseFormat) { + this.options.setResponseFormat(responseFormat); + return this; + } + + public Builder withCfgScale(Float cfgScale) { + this.options.setCfgScale(cfgScale); + return this; + } + + public Builder withClipGuidancePreset(String clipGuidancePreset) { + this.options.setClipGuidancePreset(clipGuidancePreset); + return this; + } + + public Builder withSampler(String sampler) { + this.options.setSampler(sampler); + return this; + } + + public Builder withSeed(Long seed) { + this.options.setSeed(seed); + return this; + } + + public Builder withSteps(Integer steps) { + this.options.setSteps(steps); + return this; + } + + public Builder withSamples(Integer samples) { + this.options.setN(samples); + return this; + } + + public Builder withStylePreset(String stylePreset) { + this.options.setStylePreset(stylePreset); + return this; + } + + public Builder withStylePreset(StyleEnum styleEnum) { + this.options.setStylePreset(styleEnum.toString()); + return this; + } + + public StabilityAiImageOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiApiIT.java b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiApiIT.java index 0b0c49ee7..bd98867f9 100644 --- a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiApiIT.java +++ b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.stabilityai; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.stabilityai.api.StabilityAiApi; +package org.springframework.ai.stabilityai; import java.io.File; import java.io.FileOutputStream; @@ -25,6 +22,11 @@ import java.io.IOException; import java.util.Base64; import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.stabilityai.api.StabilityAiApi; + import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "STABILITYAI_API_KEY", matches = ".*") @@ -32,6 +34,21 @@ public class StabilityAiApiIT { StabilityAiApi stabilityAiApi = new StabilityAiApi(System.getenv("STABILITYAI_API_KEY")); + private static void writeToFile(List artifacts) throws IOException { + int counter = 0; + String systemTempDir = System.getProperty("java.io.tmpdir"); + for (StabilityAiApi.GenerateImageResponse.Artifacts artifact : artifacts) { + counter++; + byte[] imageBytes = Base64.getDecoder().decode(artifact.base64()); + String fileName = String.format("dog%d.png", counter); + String filePath = systemTempDir + File.separator + fileName; + File file = new File(filePath); + try (FileOutputStream fos = new FileOutputStream(file)) { + fos.write(imageBytes); + } + } + } + @Test void generateImage() throws IOException { @@ -48,7 +65,7 @@ public class StabilityAiApiIT { .withSteps(30) .withStylePreset("photographic"); StabilityAiApi.GenerateImageRequest request = builder.build(); - StabilityAiApi.GenerateImageResponse response = stabilityAiApi.generateImage(request); + StabilityAiApi.GenerateImageResponse response = this.stabilityAiApi.generateImage(request); assertThat(response).isNotNull(); List artifacts = response.artifacts(); @@ -61,19 +78,4 @@ public class StabilityAiApiIT { } - private static void writeToFile(List artifacts) throws IOException { - int counter = 0; - String systemTempDir = System.getProperty("java.io.tmpdir"); - for (StabilityAiApi.GenerateImageResponse.Artifacts artifact : artifacts) { - counter++; - byte[] imageBytes = Base64.getDecoder().decode(artifact.base64()); - String fileName = String.format("dog%d.png", counter); - String filePath = systemTempDir + File.separator + fileName; - File file = new File(filePath); - try (FileOutputStream fos = new FileOutputStream(file)) { - fos.write(imageBytes); - } - } - } - } diff --git a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageModelIT.java b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageModelIT.java index b2de03a19..9548cefca 100644 --- a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageModelIT.java +++ b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,25 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.stabilityai; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - -import org.springframework.ai.image.Image; -import org.springframework.ai.image.ImageModel; -import org.springframework.ai.image.ImageGeneration; -import org.springframework.ai.image.ImagePrompt; -import org.springframework.ai.image.ImageResponse; -import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.test.context.SpringBootTest; - import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.util.Base64; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.image.Image; +import org.springframework.ai.image.ImageGeneration; +import org.springframework.ai.image.ImageModel; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = StabilityAiImageTestConfiguration.class) @@ -41,6 +42,16 @@ public class StabilityAiImageModelIT { @Autowired protected ImageModel stabilityAiImageModel; + private static void writeFile(Image image) throws IOException { + byte[] imageBytes = Base64.getDecoder().decode(image.getB64Json()); + String systemTempDir = System.getProperty("java.io.tmpdir"); + String filePath = systemTempDir + File.separator + "dog.png"; + File file = new File(filePath); + try (FileOutputStream fos = new FileOutputStream(file)) { + fos.write(imageBytes); + } + } + @Test void imageAsBase64Test() throws IOException { @@ -64,14 +75,4 @@ public class StabilityAiImageModelIT { writeFile(image); } - private static void writeFile(Image image) throws IOException { - byte[] imageBytes = Base64.getDecoder().decode(image.getB64Json()); - String systemTempDir = System.getProperty("java.io.tmpdir"); - String filePath = systemTempDir + File.separator + "dog.png"; - File file = new File(filePath); - try (FileOutputStream fos = new FileOutputStream(file)) { - fos.write(imageBytes); - } - } - } diff --git a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageTestConfiguration.java b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageTestConfiguration.java index c5271ff00..27690f7b5 100644 --- a/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageTestConfiguration.java +++ b/models/spring-ai-stability-ai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.stabilityai; import org.springframework.ai.stabilityai.api.StabilityAiApi; diff --git a/models/spring-ai-transformers/pom.xml b/models/spring-ai-transformers/pom.xml index 086e0064f..f32266818 100644 --- a/models/spring-ai-transformers/pom.xml +++ b/models/spring-ai-transformers/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java index 8fa20b199..a07457182 100644 --- a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java +++ b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformers; import java.io.File; diff --git a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java index e98db7ef2..213f578aa 100644 --- a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java +++ b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformers; import java.nio.FloatBuffer; @@ -23,8 +24,22 @@ import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.modality.nlp.preprocess.Tokenizer; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.onnxruntime.OnnxTensor; +import ai.onnxruntime.OnnxValue; +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; +import io.micrometer.observation.ObservationRegistry; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -43,20 +58,6 @@ import org.springframework.core.io.Resource; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import ai.djl.huggingface.tokenizers.Encoding; -import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; -import ai.djl.modality.nlp.preprocess.Tokenizer; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.types.DataType; -import ai.djl.ndarray.types.Shape; -import ai.onnxruntime.OnnxTensor; -import ai.onnxruntime.OnnxValue; -import ai.onnxruntime.OrtEnvironment; -import ai.onnxruntime.OrtException; -import ai.onnxruntime.OrtSession; -import io.micrometer.observation.ObservationRegistry; - /** * An implementation of the AbstractEmbeddingModel that uses ONNX-based Transformer models * for text embeddings. @@ -79,10 +80,6 @@ import io.micrometer.observation.ObservationRegistry; */ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implements InitializingBean { - private static final Log logger = LogFactory.getLog(TransformersEmbeddingModel.class); - - private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); - // ONNX tokenizer for the all-MiniLM-L6-v2 generative public final static String DEFAULT_ONNX_TOKENIZER_URI = "https://raw.githubusercontent.com/spring-projects/spring-ai/main/models/spring-ai-transformers/src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json"; @@ -92,8 +89,27 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement public final static String DEFAULT_MODEL_OUTPUT_NAME = "last_hidden_state"; + private static final Log logger = LogFactory.getLog(TransformersEmbeddingModel.class); + + private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + private final static int EMBEDDING_AXIS = 1; + /** + * Specifies what parts of the {@link Document}'s content and metadata will be used + * for computing the embeddings. Applicable for the {@link #embed(Document)} method + * only. Has no effect on the {@link #embed(String)} or {@link #embed(List)}. Defaults + * to {@link MetadataMode#NONE}. + */ + private final MetadataMode metadataMode; + + /** + * Observation registry used for instrumentation. + */ + private final ObservationRegistry observationRegistry; + + public Map tokenizerOptions = Map.of(); + private Resource tokenizerResource = toResource(DEFAULT_ONNX_TOKENIZER_URI); private Resource modelResource = toResource(DEFAULT_ONNX_MODEL_URI); @@ -116,14 +132,6 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement */ private OrtSession session; - /** - * Specifies what parts of the {@link Document}'s content and metadata will be used - * for computing the embeddings. Applicable for the {@link #embed(Document)} method - * only. Has no effect on the {@link #embed(String)} or {@link #embed(List)}. Defaults - * to {@link MetadataMode#NONE}. - */ - private final MetadataMode metadataMode; - /** * Resource cache directory. Used to cache remote resources, such as the ONNX models, * to the local file system. @@ -143,17 +151,10 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement */ private ResourceCacheService cacheService; - public Map tokenizerOptions = Map.of(); - private String modelOutputName = DEFAULT_MODEL_OUTPUT_NAME; private Set onnxModelInputs; - /** - * Observation registry used for instrumentation. - */ - private final ObservationRegistry observationRegistry; - /** * Conventions to use for generating observations. */ @@ -174,6 +175,10 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement this.observationRegistry = observationRegistry; } + private static Resource toResource(String uri) { + return new DefaultResourceLoader().getResource(uri); + } + public void setTokenizerOptions(Map tokenizerOptions) { this.tokenizerOptions = tokenizerOptions; } @@ -360,7 +365,7 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement return modelInputs.entrySet() .stream() - .filter(a -> onnxModelInputs.contains(a.getKey())) + .filter(a -> this.onnxModelInputs.contains(a.getKey())) .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); } @@ -399,10 +404,6 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement return sumEmbeddings.div(sumMask); } - private static Resource toResource(String uri) { - return new DefaultResourceLoader().getResource(uri); - } - /** * Use the provided convention for reporting observation data * @param observationConvention The provided convention @@ -412,4 +413,4 @@ public class TransformersEmbeddingModel extends AbstractEmbeddingModel implement this.observationConvention = observationConvention; } -} \ No newline at end of file +} diff --git a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/ResourceCacheServiceTests.java b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/ResourceCacheServiceTests.java index a8da32222..3e6aff413 100644 --- a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/ResourceCacheServiceTests.java +++ b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/ResourceCacheServiceTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformers; import java.io.File; @@ -37,27 +38,27 @@ public class ResourceCacheServiceTests { @Test public void fileResourcesAreExcludedByDefault() throws IOException { - var cache = new ResourceCacheService(tempDir); + var cache = new ResourceCacheService(this.tempDir); var originalResourceUri = "file:src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json"; var cachedResource = cache.getCachedResource(originalResourceUri); assertThat(cachedResource).isEqualTo(new DefaultResourceLoader().getResource(originalResourceUri)); - assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(0); + assertThat(Files.list(this.tempDir.toPath()).count()).isEqualTo(0); } @Test public void cacheFileResources() throws IOException { - var cache = new ResourceCacheService(tempDir); + var cache = new ResourceCacheService(this.tempDir); cache.setExcludedUriSchemas(List.of()); // erase the excluded schema names, - // including 'file'. + // including 'file'. var originalResourceUri = "file:src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json"; var cachedResource1 = cache.getCachedResource(originalResourceUri); assertThat(cachedResource1).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri)); - assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1); - assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(1); + assertThat(Files.list(this.tempDir.toPath()).count()).isEqualTo(1); + assertThat(Files.list(Files.list(this.tempDir.toPath()).iterator().next()).count()).isEqualTo(1); // Attempt to cache the same resource again should return the already cached // resource. @@ -66,17 +67,17 @@ public class ResourceCacheServiceTests { assertThat(cachedResource2).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri)); assertThat(cachedResource2).isEqualTo(cachedResource1); - assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1); - assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(1); + assertThat(Files.list(this.tempDir.toPath()).count()).isEqualTo(1); + assertThat(Files.list(Files.list(this.tempDir.toPath()).iterator().next()).count()).isEqualTo(1); } @Test public void cacheFileResourcesFromSameParentFolder() throws IOException { - var cache = new ResourceCacheService(tempDir); + var cache = new ResourceCacheService(this.tempDir); cache.setExcludedUriSchemas(List.of()); // erase the excluded schema names, - // including 'file'. + // including 'file'. var originalResourceUri1 = "file:src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json"; var cachedResource1 = cache.getCachedResource(originalResourceUri1); @@ -89,23 +90,23 @@ public class ResourceCacheServiceTests { assertThat(cachedResource2).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri1)); assertThat(cachedResource2).isNotEqualTo(cachedResource1); - assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1) + assertThat(Files.list(this.tempDir.toPath()).count()).isEqualTo(1) .describedAs( "As both resources come from the same parent segments they should be cached in a single common parent."); - assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(2); + assertThat(Files.list(Files.list(this.tempDir.toPath()).iterator().next()).count()).isEqualTo(2); } @Test public void cacheHttpResources() throws IOException { - var cache = new ResourceCacheService(tempDir); + var cache = new ResourceCacheService(this.tempDir); var originalResourceUri1 = "https://raw.githubusercontent.com/spring-projects/spring-ai/main/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties"; var cachedResource1 = cache.getCachedResource(originalResourceUri1); assertThat(cachedResource1).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri1)); - assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1); - assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(1); + assertThat(Files.list(this.tempDir.toPath()).count()).isEqualTo(1); + assertThat(Files.list(Files.list(this.tempDir.toPath()).iterator().next()).count()).isEqualTo(1); } } diff --git a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelObservationTests.java b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelObservationTests.java index ec3c9c5ad..3f91ad52e 100644 --- a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelObservationTests.java +++ b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelObservationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.transformers; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.transformers; import java.util.List; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; @@ -35,8 +37,7 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * Integration tests for observation instrumentation in {@link OpenAiEmbeddingModel}. @@ -59,13 +60,13 @@ public class TransformersEmbeddingModelObservationTests { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelTests.java b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelTests.java index 8c2fb3f01..02056499b 100644 --- a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelTests.java +++ b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingModelTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformers; import java.text.DecimalFormat; diff --git a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java index be37188b3..8119bbca5 100644 --- a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java +++ b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformers.samples; import java.nio.FloatBuffer; @@ -125,4 +126,4 @@ public class ONNXSample { return manager.create(buffer, new Shape(data.length, data[0].length, data[0][0].length)); } -} \ No newline at end of file +} diff --git a/models/spring-ai-vertex-ai-embedding/pom.xml b/models/spring-ai-vertex-ai-embedding/pom.xml index b94de96bd..0ce34354e 100644 --- a/models/spring-ai-vertex-ai-embedding/pom.xml +++ b/models/spring-ai-vertex-ai-embedding/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingConnectionDetails.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingConnectionDetails.java index 3f2fb6ec7..7f2cc2eb4 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingConnectionDetails.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.embedding; import java.io.IOException; -import org.springframework.util.StringUtils; - import com.google.cloud.aiplatform.v1.EndpointName; import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import org.springframework.util.StringUtils; + /** * VertexAiEmbeddingConnectionDetails represents the details of a connection to the Vertex * AI embedding service. It provides methods to access the project ID, location, @@ -33,15 +34,13 @@ import com.google.cloud.aiplatform.v1.PredictionServiceSettings; */ public class VertexAiEmbeddingConnectionDetails { - private static final String DEFAULT_LOCATION = "us-central1"; - public static final String DEFAULT_ENDPOINT = "us-central1-aiplatform.googleapis.com:443"; public static final String DEFAULT_ENDPOINT_SUFFIX = "-aiplatform.googleapis.com:443"; public static final String DEFAULT_PUBLISHER = "google"; - private PredictionServiceSettings predictionServiceSettings; + private static final String DEFAULT_LOCATION = "us-central1"; /** * Your project ID. @@ -59,6 +58,8 @@ public class VertexAiEmbeddingConnectionDetails { private final String publisher; + private PredictionServiceSettings predictionServiceSettings; + public VertexAiEmbeddingConnectionDetails(String endpoint, String projectId, String location, String publisher) { this.projectId = projectId; this.location = location; @@ -76,6 +77,27 @@ public class VertexAiEmbeddingConnectionDetails { return new Builder(); } + public String getProjectId() { + return this.projectId; + } + + public String getLocation() { + return this.location; + } + + public String getPublisher() { + return this.publisher; + } + + public EndpointName getEndpointName(String modelName) { + return EndpointName.ofProjectLocationPublisherModelName(this.projectId, this.location, this.publisher, + modelName); + } + + public PredictionServiceSettings getPredictionServiceSettings() { + return this.predictionServiceSettings; + } + public static class Builder { /** @@ -143,25 +165,4 @@ public class VertexAiEmbeddingConnectionDetails { } - public String getProjectId() { - return this.projectId; - } - - public String getLocation() { - return this.location; - } - - public String getPublisher() { - return this.publisher; - } - - public EndpointName getEndpointName(String modelName) { - return EndpointName.ofProjectLocationPublisherModelName(this.projectId, this.location, this.publisher, - modelName); - } - - public PredictionServiceSettings getPredictionServiceSettings() { - return this.predictionServiceSettings; - } - } diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUsage.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUsage.java index ef0152c23..602afbd80 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUsage.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUsage.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vertexai.embedding; import org.springframework.ai.chat.metadata.Usage; diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java index a160baeda..caac760df 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java @@ -1,33 +1,33 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vertexai.embedding; import java.nio.charset.StandardCharsets; import java.util.Base64; -import java.util.List; - -import org.springframework.util.Assert; -import org.springframework.util.MimeType; -import org.springframework.util.StringUtils; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Struct; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; +import org.springframework.util.StringUtils; + /** * Utility class for constructing parameter objects for Vertex AI embedding requests. * @@ -36,6 +36,39 @@ import com.google.protobuf.util.JsonFormat; */ public abstract class VertexAiEmbeddingUtils { + public static Value valueOf(boolean n) { + return Value.newBuilder().setBoolValue(n).build(); + } + + public static Value valueOf(String s) { + return Value.newBuilder().setStringValue(s).build(); + } + + public static Value valueOf(int n) { + return Value.newBuilder().setNumberValue(n).build(); + } + + public static Value valueOf(Struct struct) { + return Value.newBuilder().setStructValue(struct).build(); + } + + // Convert a Json string to a protobuf.Value + public static Value jsonToValue(String json) throws InvalidProtocolBufferException { + Value.Builder builder = Value.newBuilder(); + JsonFormat.parser().merge(json, builder); + return builder.build(); + } + + public static float[] toVector(Value value) { + float[] floats = new float[value.getListValue().getValuesList().size()]; + int index = 0; + for (Value v : value.getListValue().getValuesList()) { + double d = v.getNumberValue(); + floats[index++] = Double.valueOf(d).floatValue(); + } + return floats; + } + ////////////////////////////////////////////////////// // Text Only ////////////////////////////////////////////////////// @@ -404,37 +437,4 @@ public abstract class VertexAiEmbeddingUtils { } - public static Value valueOf(boolean n) { - return Value.newBuilder().setBoolValue(n).build(); - } - - public static Value valueOf(String s) { - return Value.newBuilder().setStringValue(s).build(); - } - - public static Value valueOf(int n) { - return Value.newBuilder().setNumberValue(n).build(); - } - - public static Value valueOf(Struct struct) { - return Value.newBuilder().setStructValue(struct).build(); - } - - // Convert a Json string to a protobuf.Value - public static Value jsonToValue(String json) throws InvalidProtocolBufferException { - Value.Builder builder = Value.newBuilder(); - JsonFormat.parser().merge(json, builder); - return builder.build(); - } - - public static float[] toVector(Value value) { - float[] floats = new float[value.getListValue().getValuesList().size()]; - int index = 0; - for (Value v : value.getListValue().getValuesList()) { - double d = v.getNumberValue(); - floats[index++] = Double.valueOf(d).floatValue(); - } - return floats; - } - } diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java index 5dca008c0..b13efe0dd 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.embedding.multimodal; +import java.util.ArrayList; +import java.util.EnumMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + import com.google.cloud.aiplatform.v1.EndpointName; import com.google.cloud.aiplatform.v1.PredictRequest; import com.google.cloud.aiplatform.v1.PredictResponse; @@ -23,7 +31,7 @@ import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Value; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.model.Media; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.DocumentEmbeddingModel; @@ -34,6 +42,7 @@ import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; import org.springframework.ai.embedding.EmbeddingResultMetadata; import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType; +import org.springframework.ai.model.Media; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage; @@ -47,13 +56,6 @@ import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; import org.springframework.util.StringUtils; -import java.util.ArrayList; -import java.util.EnumMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.Stream; - /** * Implementation of the Vertex AI Multimodal Embedding Model. Note: This implementation * is not yet fully functional and is subject to change. @@ -66,8 +68,6 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel private static final Logger logger = LoggerFactory.getLogger(VertexAiMultimodalEmbeddingModel.class); - public final VertexAiMultimodalEmbeddingOptions defaultOptions; - private static final MimeType TEXT_MIME_TYPE = MimeTypeUtils.parseMimeType("text/*"); private static final MimeType IMAGE_MIME_TYPE = MimeTypeUtils.parseMimeType("image/*"); @@ -77,6 +77,13 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel private static final List SUPPORTED_IMAGE_MIME_SUB_TYPES = List.of(MimeTypeUtils.IMAGE_JPEG, MimeTypeUtils.IMAGE_GIF, MimeTypeUtils.IMAGE_PNG, MimeTypeUtils.parseMimeType("image/bmp")); + private static final Map KNOWN_EMBEDDING_DIMENSIONS = Stream + .of(VertexAiMultimodalEmbeddingModelName.values()) + .collect(Collectors.toMap(VertexAiMultimodalEmbeddingModelName::getName, + VertexAiMultimodalEmbeddingModelName::getDimensions)); + + public final VertexAiMultimodalEmbeddingOptions defaultOptions; + private final VertexAiEmbeddingConnectionDetails connectionDetails; public VertexAiMultimodalEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails, @@ -123,9 +130,6 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel return finalResponse; } - record DocumentMetadata(String documentId, MimeType mimeType, Object data) { - } - private EmbeddingResponse doSingleDocumentPrediction(PredictionServiceClient client, EndpointName endpointName, Document document, VertexAiMultimodalEmbeddingOptions mergedOptions) throws InvalidProtocolBufferException { @@ -252,9 +256,8 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), 768); } - private static final Map KNOWN_EMBEDDING_DIMENSIONS = Stream - .of(VertexAiMultimodalEmbeddingModelName.values()) - .collect(Collectors.toMap(VertexAiMultimodalEmbeddingModelName::getName, - VertexAiMultimodalEmbeddingModelName::getDimensions)); + record DocumentMetadata(String documentId, MimeType mimeType, Object data) { -} \ No newline at end of file + } + +} diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelName.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelName.java index 5dc546b5f..750d9816a 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelName.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelName.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.embedding.multimodal; import org.springframework.ai.model.EmbeddingModelDescription; diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingOptions.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingOptions.java index 78a75fef2..89762581c 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingOptions.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vertexai.embedding.multimodal; -import org.springframework.ai.embedding.EmbeddingOptions; -import org.springframework.util.StringUtils; +package org.springframework.ai.vertexai.embedding.multimodal; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.util.StringUtils; + /** * Class representing the options for Vertex AI Multimodal Embedding. * @@ -105,6 +106,48 @@ public class VertexAiMultimodalEmbeddingOptions implements EmbeddingOptions { return new Builder(); } + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Integer getDimensions() { + return this.dimensions; + } + + public void setDimensions(Integer dimensions) { + this.dimensions = dimensions; + } + + public Integer getVideoStartOffsetSec() { + return this.videoStartOffsetSec; + } + + public void setVideoStartOffsetSec(Integer videoStartOffsetSec) { + this.videoStartOffsetSec = videoStartOffsetSec; + } + + public Integer getVideoEndOffsetSec() { + return this.videoEndOffsetSec; + } + + public void setVideoEndOffsetSec(Integer videoEndOffsetSec) { + this.videoEndOffsetSec = videoEndOffsetSec; + } + + public Integer getVideoIntervalSec() { + return this.videoIntervalSec; + } + + public void setVideoIntervalSec(Integer videoIntervalSec) { + this.videoIntervalSec = videoIntervalSec; + } + public static class Builder { protected VertexAiMultimodalEmbeddingOptions options; @@ -168,46 +211,4 @@ public class VertexAiMultimodalEmbeddingOptions implements EmbeddingOptions { } - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - @Override - public Integer getDimensions() { - return this.dimensions; - } - - public void setDimensions(Integer dimensions) { - this.dimensions = dimensions; - } - - public Integer getVideoStartOffsetSec() { - return this.videoStartOffsetSec; - } - - public void setVideoStartOffsetSec(Integer videoStartOffsetSec) { - this.videoStartOffsetSec = videoStartOffsetSec; - } - - public Integer getVideoEndOffsetSec() { - return this.videoEndOffsetSec; - } - - public void setVideoEndOffsetSec(Integer videoEndOffsetSec) { - this.videoEndOffsetSec = videoEndOffsetSec; - } - - public Integer getVideoIntervalSec() { - return this.videoIntervalSec; - } - - public void setVideoIntervalSec(Integer videoIntervalSec) { - this.videoIntervalSec = videoIntervalSec; - } - } diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java index 26da5fea0..31d2846e9 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.embedding.text; import java.io.IOException; @@ -22,6 +23,13 @@ import java.util.Map; import java.util.stream.Collectors; import java.util.stream.Stream; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.protobuf.Value; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -46,14 +54,6 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import com.google.cloud.aiplatform.v1.EndpointName; -import com.google.cloud.aiplatform.v1.PredictRequest; -import com.google.cloud.aiplatform.v1.PredictResponse; -import com.google.cloud.aiplatform.v1.PredictionServiceClient; -import com.google.protobuf.Value; - -import io.micrometer.observation.ObservationRegistry; - /** * A class representing a Vertex AI Text Embedding Model. * @@ -65,6 +65,11 @@ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel { private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + private static final Map KNOWN_EMBEDDING_DIMENSIONS = Stream + .of(VertexAiTextEmbeddingModelName.values()) + .collect(Collectors.toMap(VertexAiTextEmbeddingModelName::getName, + VertexAiTextEmbeddingModelName::getDimensions)); + public final VertexAiTextEmbeddingOptions defaultOptions; private final VertexAiEmbeddingConnectionDetails connectionDetails; @@ -131,7 +136,7 @@ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel { PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName, finalOptions); - PredictResponse embeddingResponse = retryTemplate + PredictResponse embeddingResponse = this.retryTemplate .execute(context -> getPredictResponse(client, predictRequestBuilder)); int index = 0; @@ -228,11 +233,6 @@ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel { return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions()); } - private static final Map KNOWN_EMBEDDING_DIMENSIONS = Stream - .of(VertexAiTextEmbeddingModelName.values()) - .collect(Collectors.toMap(VertexAiTextEmbeddingModelName::getName, - VertexAiTextEmbeddingModelName::getDimensions)); - /** * Use the provided convention for reporting observation data * @param observationConvention The provided convention @@ -242,4 +242,4 @@ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel { this.observationConvention = observationConvention; } -} \ No newline at end of file +} diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelName.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelName.java index c49471d06..327d7950c 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelName.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelName.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.embedding.text; import org.springframework.ai.model.EmbeddingModelDescription; diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java index fe08b2d4b..4de1f3375 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vertexai.embedding.text; -import org.springframework.ai.embedding.EmbeddingOptions; -import org.springframework.util.StringUtils; +package org.springframework.ai.vertexai.embedding.text; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.util.StringUtils; + /** * @author Christian Tzolov * @since 1.0.0 @@ -31,6 +32,100 @@ public class VertexAiTextEmbeddingOptions implements EmbeddingOptions { public static final String DEFAULT_MODEL_NAME = VertexAiTextEmbeddingModelName.TEXT_EMBEDDING_004.getName(); + /** + * The embedding model name to use. Supported models are: text-embedding-004, + * text-multilingual-embedding-002 and multimodalembedding@001. + */ + private @JsonProperty("model") String model; + + // @formatter:off + + /** + * The intended downstream application to help the model produce better quality embeddings. + * Not all model versions support all task types. + */ + private @JsonProperty("task") TaskType taskType; + + /** + * The number of dimensions the resulting output embeddings should have. + * Supported for model version 004 and later. You can use this parameter to reduce the + * embedding size, for example, for storage optimization. + */ + private @JsonProperty("dimensions") Integer dimensions; + + /** + * Optional title, only valid with task_type=RETRIEVAL_DOCUMENT. + */ + private @JsonProperty("title") String title; + + /** + * When set to true, input text will be truncated. When set to false, an error is returned + * if the input text is longer than the maximum length supported by the model. Defaults to true. + */ + private @JsonProperty("autoTruncate") Boolean autoTruncate; + + public static Builder builder() { + return new Builder(); + } + + + // @formatter:on + + public VertexAiTextEmbeddingOptions initializeDefaults() { + + if (this.getTaskType() == null) { + this.setTaskType(TaskType.RETRIEVAL_DOCUMENT); + } + + if (StringUtils.hasText(this.getTitle()) && this.getTaskType() != TaskType.RETRIEVAL_DOCUMENT) { + throw new IllegalArgumentException("Title is only valid with task_type=RETRIEVAL_DOCUMENT"); + } + + return this; + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public TaskType getTaskType() { + return this.taskType; + } + + public void setTaskType(TaskType taskType) { + this.taskType = taskType; + } + + @Override + public Integer getDimensions() { + return this.dimensions; + } + + public void setDimensions(Integer dimensions) { + this.dimensions = dimensions; + } + + public String getTitle() { + return this.title; + } + + public void setTitle(String user) { + this.title = user; + } + + public Boolean getAutoTruncate() { + return this.autoTruncate; + } + + public void setAutoTruncate(Boolean autoTruncate) { + this.autoTruncate = autoTruncate; + } + public enum TaskType { /** @@ -71,45 +166,6 @@ public class VertexAiTextEmbeddingOptions implements EmbeddingOptions { } - // @formatter:off - /** - * The embedding model name to use. Supported models are: - * text-embedding-004, text-multilingual-embedding-002 and multimodalembedding@001. - */ - private @JsonProperty("model") String model; - - /** - * The intended downstream application to help the model produce better quality embeddings. - * Not all model versions support all task types. - */ - private @JsonProperty("task") TaskType taskType; - - /** - * The number of dimensions the resulting output embeddings should have. - * Supported for model version 004 and later. You can use this parameter to reduce the - * embedding size, for example, for storage optimization. - */ - private @JsonProperty("dimensions") Integer dimensions; - - /** - * Optional title, only valid with task_type=RETRIEVAL_DOCUMENT. - */ - private @JsonProperty("title") String title; - - - /** - * When set to true, input text will be truncated. When set to false, an error is returned - * if the input text is longer than the maximum length supported by the model. Defaults to true. - */ - private @JsonProperty("autoTruncate") Boolean autoTruncate; - - - // @formatter:on - - public static Builder builder() { - return new Builder(); - } - public static class Builder { protected VertexAiTextEmbeddingOptions options; @@ -170,59 +226,4 @@ public class VertexAiTextEmbeddingOptions implements EmbeddingOptions { } - public VertexAiTextEmbeddingOptions initializeDefaults() { - - if (this.getTaskType() == null) { - this.setTaskType(TaskType.RETRIEVAL_DOCUMENT); - } - - if (StringUtils.hasText(this.getTitle()) && this.getTaskType() != TaskType.RETRIEVAL_DOCUMENT) { - throw new IllegalArgumentException("Title is only valid with task_type=RETRIEVAL_DOCUMENT"); - } - - return this; - } - - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - public TaskType getTaskType() { - return this.taskType; - } - - public void setTaskType(TaskType taskType) { - this.taskType = taskType; - } - - @Override - public Integer getDimensions() { - return this.dimensions; - } - - public void setDimensions(Integer dimensions) { - this.dimensions = dimensions; - } - - public String getTitle() { - return this.title; - } - - public void setTitle(String user) { - this.title = user; - } - - public Boolean getAutoTruncate() { - return this.autoTruncate; - } - - public void setAutoTruncate(Boolean autoTruncate) { - this.autoTruncate = autoTruncate; - } - } diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java index 6caf87432..b92079d55 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vertexai.embedding.multimodal; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vertexai.embedding.multimodal; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.model.Media; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.DocumentEmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResultMetadata; +import org.springframework.ai.model.Media; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; @@ -33,6 +33,8 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; +import static org.assertj.core.api.Assertions.assertThat; + @SpringBootTest(classes = VertexAiMultimodalEmbeddingModelIT.Config.class) @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") @@ -49,7 +51,7 @@ class VertexAiMultimodalEmbeddingModelIT { DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(new Document("Hello World"), new Document("Hello World2")); - EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.multiModelEmbeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).hasSize(2); assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) @@ -76,7 +78,7 @@ class VertexAiMultimodalEmbeddingModelIT { .as("Total tokens in metadata should be 0") .isEqualTo(0L); - assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + assertThat(this.multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @Test @@ -86,7 +88,7 @@ class VertexAiMultimodalEmbeddingModelIT { DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); - EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.multiModelEmbeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) @@ -98,18 +100,18 @@ class VertexAiMultimodalEmbeddingModelIT { assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); - assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + assertThat(this.multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @Test void textMediaEmbedding() { - assertThat(multiModelEmbeddingModel).isNotNull(); + assertThat(this.multiModelEmbeddingModel).isNotNull(); var document = Document.builder().withMedia(new Media(MimeTypeUtils.TEXT_PLAIN, "Hello World")).build(); DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); - EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.multiModelEmbeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) @@ -121,7 +123,7 @@ class VertexAiMultimodalEmbeddingModelIT { assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); - assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + assertThat(this.multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @Test @@ -133,7 +135,7 @@ class VertexAiMultimodalEmbeddingModelIT { DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); - EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.multiModelEmbeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); @@ -147,7 +149,7 @@ class VertexAiMultimodalEmbeddingModelIT { assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); - assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + assertThat(this.multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @Test @@ -159,7 +161,7 @@ class VertexAiMultimodalEmbeddingModelIT { DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); - EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.multiModelEmbeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); @@ -172,7 +174,7 @@ class VertexAiMultimodalEmbeddingModelIT { assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); - assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + assertThat(this.multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @Test @@ -186,7 +188,7 @@ class VertexAiMultimodalEmbeddingModelIT { DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); - EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.multiModelEmbeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).hasSize(3); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) @@ -206,7 +208,7 @@ class VertexAiMultimodalEmbeddingModelIT { assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001"); assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0); - assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + assertThat(this.multiModelEmbeddingModel.dimensions()).isEqualTo(1408); } @SpringBootConfiguration diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java index 090d683e5..baaeedd8a 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -20,12 +20,11 @@ import com.google.cloud.aiplatform.v1.EndpointName; import com.google.cloud.aiplatform.v1.PredictRequest; import com.google.cloud.aiplatform.v1.PredictResponse; import com.google.cloud.aiplatform.v1.PredictionServiceClient; + import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; import org.springframework.retry.support.RetryTemplate; -import java.io.IOException; - public class TestVertexAiTextEmbeddingModel extends VertexAiTextEmbeddingModel { private PredictionServiceClient mockPredictionServiceClient; @@ -43,16 +42,16 @@ public class TestVertexAiTextEmbeddingModel extends VertexAiTextEmbeddingModel { @Override PredictionServiceClient createPredictionServiceClient() { - if (mockPredictionServiceClient != null) { - return mockPredictionServiceClient; + if (this.mockPredictionServiceClient != null) { + return this.mockPredictionServiceClient; } return super.createPredictionServiceClient(); } @Override PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) { - if (mockPredictionServiceClient != null) { - return mockPredictionServiceClient.predict(predictRequestBuilder.build()); + if (this.mockPredictionServiceClient != null) { + return this.mockPredictionServiceClient.predict(predictRequestBuilder.build()); } return super.getPredictResponse(client, predictRequestBuilder); } @@ -64,8 +63,8 @@ public class TestVertexAiTextEmbeddingModel extends VertexAiTextEmbeddingModel { @Override protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName, VertexAiTextEmbeddingOptions finalOptions) { - if (mockPredictRequestBuilder != null) { - return mockPredictRequestBuilder; + if (this.mockPredictRequestBuilder != null) { + return this.mockPredictRequestBuilder; } return super.getPredictRequestBuilder(request, endpointName, finalOptions); } diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java index f98c96b1b..d1701b7a8 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vertexai.embedding.text; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vertexai.embedding.text; import java.util.List; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; + import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; @@ -30,6 +30,8 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; +import static org.assertj.core.api.Assertions.assertThat; + @SpringBootTest(classes = VertexAiTextEmbeddingModelIT.Config.class) @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") @@ -43,11 +45,11 @@ class VertexAiTextEmbeddingModelIT { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "text-embedding-004", "text-multilingual-embedding-002" }) void defaultEmbedding(String modelName) { - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); var options = VertexAiTextEmbeddingOptions.builder().withModel(modelName).build(); - EmbeddingResponse embeddingResponse = embeddingModel + EmbeddingResponse embeddingResponse = this.embeddingModel .call(new EmbeddingRequest(List.of("Hello World", "World is Big"), options)); assertThat(embeddingResponse.getResults()).hasSize(2); @@ -60,7 +62,7 @@ class VertexAiTextEmbeddingModelIT { .as("Total tokens in metadata should be 5") .isEqualTo(5L); - assertThat(embeddingModel.dimensions()).isEqualTo(768); + assertThat(this.embeddingModel.dimensions()).isEqualTo(768); } @SpringBootConfiguration diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelObservationIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelObservationIT.java index f6ac7c5b5..9a277d403 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelObservationIT.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vertexai.embedding.text; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vertexai.embedding.text; import java.util.List; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; @@ -36,9 +39,7 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * Integration tests for observation instrumentation in {@link OpenAiEmbeddingModel}. @@ -66,13 +67,13 @@ public class VertexAiTextEmbeddingModelObservationIT { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java index 2638e97e7..5757fe5a4 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,9 +16,11 @@ package org.springframework.ai.vertexai.embedding.text; +import java.util.List; + import com.google.cloud.aiplatform.v1.PredictRequest; -import com.google.cloud.aiplatform.v1.PredictionServiceClient; import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; import com.google.cloud.aiplatform.v1.PredictionServiceSettings; import com.google.protobuf.Struct; import com.google.protobuf.Value; @@ -27,8 +29,9 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.ai.embedding.EmbeddingResponse; + import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.retry.TransientAiException; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; @@ -37,8 +40,6 @@ import org.springframework.retry.RetryContext; import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; @@ -52,25 +53,6 @@ import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) public class VertexAiTextEmbeddingRetryTests { - private static class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private RetryTemplate retryTemplate; @@ -91,15 +73,15 @@ public class VertexAiTextEmbeddingRetryTests { @BeforeEach public void setUp() { - retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); + this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.registerListener(this.retryListener); - embeddingModel = new TestVertexAiTextEmbeddingModel(mockConnectionDetails, - VertexAiTextEmbeddingOptions.builder().build(), retryTemplate); - embeddingModel.setMockPredictionServiceClient(mockPredictionServiceClient); - embeddingModel.setMockPredictRequestBuilder(mockPredictRequestBuilder); - when(mockPredictRequestBuilder.build()).thenReturn(PredictRequest.getDefaultInstance()); + this.embeddingModel = new TestVertexAiTextEmbeddingModel(this.mockConnectionDetails, + VertexAiTextEmbeddingOptions.builder().build(), this.retryTemplate); + this.embeddingModel.setMockPredictionServiceClient(this.mockPredictionServiceClient); + this.embeddingModel.setMockPredictRequestBuilder(this.mockPredictRequestBuilder); + when(this.mockPredictRequestBuilder.build()).thenReturn(PredictRequest.getDefaultInstance()); } @Test @@ -130,32 +112,51 @@ public class VertexAiTextEmbeddingRetryTests { .build(); // Setup the mock PredictionServiceClient - when(mockPredictionServiceClient.predict(any())).thenThrow(new TransientAiException("Transient Error 1")) + when(this.mockPredictionServiceClient.predict(any())).thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(mockResponse); - EmbeddingResponse result = embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), null)); + EmbeddingResponse result = this.embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), null)); assertThat(result).isNotNull(); assertThat(result.getResults()).hasSize(1); assertThat(result.getResults().get(0).getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); - verify(mockPredictRequestBuilder, times(3)).build(); + verify(this.mockPredictRequestBuilder, times(3)).build(); } @Test public void vertexAiEmbeddingNonTransientError() { // Setup the mock PredictionServiceClient to throw a non-transient error - when(mockPredictionServiceClient.predict(any())).thenThrow(new RuntimeException("Non Transient Error")); + when(this.mockPredictionServiceClient.predict(any())).thenThrow(new RuntimeException("Non Transient Error")); // Assert that a RuntimeException is thrown and not retried assertThrows(RuntimeException.class, - () -> embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), null))); + () -> this.embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), null))); // Verify that predict was called only once (no retries for non-transient errors) - verify(mockPredictionServiceClient, times(1)).predict(any()); + verify(this.mockPredictionServiceClient, times(1)).predict(any()); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + } } diff --git a/models/spring-ai-vertex-ai-gemini/pom.xml b/models/spring-ai-vertex-ai-gemini/pom.xml index 57f5dcd5c..230c5cd67 100644 --- a/models/spring-ai-vertex-ai-gemini/pom.xml +++ b/models/spring-ai-vertex-ai-gemini/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java index c1d7c34bb..fe5e8e52e 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini; import java.io.File; @@ -56,21 +57,6 @@ public abstract class MimeTypeDetector { */ private static final Map GEMINI_MIME_TYPES = new HashMap<>(); - static { - // Custom MIME type mappings here - GEMINI_MIME_TYPES.put("png", MimeTypeUtils.IMAGE_PNG); - GEMINI_MIME_TYPES.put("jpeg", MimeTypeUtils.IMAGE_JPEG); - GEMINI_MIME_TYPES.put("jpg", MimeTypeUtils.IMAGE_JPEG); - GEMINI_MIME_TYPES.put("gif", MimeTypeUtils.IMAGE_GIF); - GEMINI_MIME_TYPES.put("mov", new MimeType("video", "mov")); - GEMINI_MIME_TYPES.put("mp4", new MimeType("video", "mp4")); - GEMINI_MIME_TYPES.put("mpg", new MimeType("video", "mpg")); - GEMINI_MIME_TYPES.put("avi", new MimeType("video", "avi")); - GEMINI_MIME_TYPES.put("wmv", new MimeType("video", "wmv")); - GEMINI_MIME_TYPES.put("mpegps", new MimeType("mpegps", "mp4")); - GEMINI_MIME_TYPES.put("flv", new MimeType("video", "flv")); - } - public static MimeType getMimeType(URL url) { return getMimeType(url.getFile()); } @@ -115,4 +101,19 @@ public abstract class MimeTypeDetector { String.format("Unable to detect the MIME type of '%s'. Please provide it explicitly.", path)); } + static { + // Custom MIME type mappings here + GEMINI_MIME_TYPES.put("png", MimeTypeUtils.IMAGE_PNG); + GEMINI_MIME_TYPES.put("jpeg", MimeTypeUtils.IMAGE_JPEG); + GEMINI_MIME_TYPES.put("jpg", MimeTypeUtils.IMAGE_JPEG); + GEMINI_MIME_TYPES.put("gif", MimeTypeUtils.IMAGE_GIF); + GEMINI_MIME_TYPES.put("mov", new MimeType("video", "mov")); + GEMINI_MIME_TYPES.put("mp4", new MimeType("video", "mp4")); + GEMINI_MIME_TYPES.put("mpg", new MimeType("video", "mpg")); + GEMINI_MIME_TYPES.put("avi", new MimeType("video", "avi")); + GEMINI_MIME_TYPES.put("wmv", new MimeType("video", "wmv")); + GEMINI_MIME_TYPES.put("mpegps", new MimeType("mpegps", "mp4")); + GEMINI_MIME_TYPES.put("flv", new MimeType("video", "flv")); + } + } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index dfe008a47..67d6bed36 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,41 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.google.cloud.vertexai.VertexAI; -import com.google.cloud.vertexai.api.*; +import com.google.cloud.vertexai.api.Candidate; import com.google.cloud.vertexai.api.Candidate.FinishReason; +import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.FunctionCall; +import com.google.cloud.vertexai.api.FunctionDeclaration; +import com.google.cloud.vertexai.api.FunctionResponse; +import com.google.cloud.vertexai.api.GenerateContentResponse; +import com.google.cloud.vertexai.api.GenerationConfig; +import com.google.cloud.vertexai.api.GoogleSearchRetrieval; +import com.google.cloud.vertexai.api.Part; +import com.google.cloud.vertexai.api.Schema; +import com.google.cloud.vertexai.api.Tool; import com.google.cloud.vertexai.generativeai.GenerativeModel; import com.google.cloud.vertexai.generativeai.PartMaker; import com.google.cloud.vertexai.generativeai.ResponseStream; import com.google.protobuf.Struct; import com.google.protobuf.util.JsonFormat; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; @@ -60,18 +83,6 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; -import reactor.core.publisher.Flux; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - /** * @author Christian Tzolov * @author Grogdunn @@ -106,54 +117,6 @@ public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements */ private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; - public enum GeminiMessageType { - - USER("user"), - - MODEL("model"); - - GeminiMessageType(String value) { - this.value = value; - } - - public final String value; - - public String getValue() { - return this.value; - } - - } - - public enum ChatModel implements ChatModelDescription { - - /** - * Deprecated by Goolgle in favor of 1.5 pro and flash models. - */ - GEMINI_PRO_VISION("gemini-pro-vision"), - - GEMINI_PRO("gemini-pro"), - - GEMINI_1_5_PRO("gemini-1.5-pro-001"), - - GEMINI_1_5_FLASH("gemini-1.5-flash-001"); - - ChatModel(String value) { - this.value = value; - } - - public final String value; - - public String getValue() { - return this.value; - } - - @Override - public String getName() { - return this.value; - } - - } - public VertexAiGeminiChatModel(VertexAI vertexAI) { this(vertexAI, VertexAiGeminiChatOptions.builder().withModel(ChatModel.GEMINI_1_5_PRO).withTemperature(0.8).build()); @@ -198,6 +161,124 @@ public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements this.observationRegistry = observationRegistry; } + private static GeminiMessageType toGeminiMessageType(@NonNull MessageType type) { + + Assert.notNull(type, "Message type must not be null"); + + switch (type) { + case SYSTEM: + case USER: + case TOOL: + return GeminiMessageType.USER; + case ASSISTANT: + return GeminiMessageType.MODEL; + default: + throw new IllegalArgumentException("Unsupported message type: " + type); + } + } + + static List messageToGeminiParts(Message message) { + + if (message instanceof SystemMessage systemMessage) { + + List parts = new ArrayList<>(); + + if (systemMessage.getContent() != null) { + parts.add(Part.newBuilder().setText(systemMessage.getContent()).build()); + } + + return parts; + } + else if (message instanceof UserMessage userMessage) { + List parts = new ArrayList<>(); + if (userMessage.getContent() != null) { + parts.add(Part.newBuilder().setText(userMessage.getContent()).build()); + } + + parts.addAll(mediaToParts(userMessage.getMedia())); + + return parts; + } + else if (message instanceof AssistantMessage assistantMessage) { + List parts = new ArrayList<>(); + if (StringUtils.hasText(assistantMessage.getContent())) { + parts.add(Part.newBuilder().setText(assistantMessage.getContent()).build()); + } + if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { + parts.addAll(assistantMessage.getToolCalls() + .stream() + .map(toolCall -> Part.newBuilder() + .setFunctionCall(FunctionCall.newBuilder() + .setName(toolCall.name()) + .setArgs(jsonToStruct(toolCall.arguments())) + .build()) + .build()) + .toList()); + } + return parts; + } + else if (message instanceof ToolResponseMessage toolResponseMessage) { + + return toolResponseMessage.getResponses() + .stream() + .map(response -> Part.newBuilder() + .setFunctionResponse(FunctionResponse.newBuilder() + .setName(response.name()) + .setResponse(jsonToStruct(response.responseData())) + .build()) + .build()) + .toList(); + } + else { + throw new IllegalArgumentException("Gemini doesn't support message type: " + message.getClass()); + } + } + + private static List mediaToParts(Collection media) { + List parts = new ArrayList<>(); + + List mediaParts = media.stream() + .map(mediaData -> PartMaker.fromMimeTypeAndData(mediaData.getMimeType().toString(), mediaData.getData())) + .toList(); + + if (!CollectionUtils.isEmpty(mediaParts)) { + parts.addAll(mediaParts); + } + + return parts; + } + + private static String structToJson(Struct struct) { + try { + return JsonFormat.printer().print(struct); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static Struct jsonToStruct(String json) { + try { + var structBuilder = Struct.newBuilder(); + JsonFormat.parser().ignoringUnknownFields().merge(json, structBuilder); + return structBuilder.build(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static Schema jsonToSchema(String json) { + try { + var schemaBuilder = Schema.newBuilder(); + JsonFormat.parser().ignoringUnknownFields().merge(json, schemaBuilder); + return schemaBuilder.build(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + // https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini @Override public ChatResponse call(Prompt prompt) { @@ -346,10 +427,6 @@ public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements return ChatResponseMetadata.builder().withUsage(new VertexAiUsage(response.getUsageMetadata())).build(); } - @JsonInclude(Include.NON_NULL) - public record GeminiRequest(List contents, GenerativeModel model) { - } - private VertexAiGeminiChatOptions vertexAiGeminiChatOptions(Prompt prompt) { VertexAiGeminiChatOptions updatedRuntimeOptions = VertexAiGeminiChatOptions.builder().build(); if (prompt.getOptions() != null) { @@ -480,93 +557,6 @@ public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements return contents; } - private static GeminiMessageType toGeminiMessageType(@NonNull MessageType type) { - - Assert.notNull(type, "Message type must not be null"); - - switch (type) { - case SYSTEM: - case USER: - case TOOL: - return GeminiMessageType.USER; - case ASSISTANT: - return GeminiMessageType.MODEL; - default: - throw new IllegalArgumentException("Unsupported message type: " + type); - } - } - - static List messageToGeminiParts(Message message) { - - if (message instanceof SystemMessage systemMessage) { - - List parts = new ArrayList<>(); - - if (systemMessage.getContent() != null) { - parts.add(Part.newBuilder().setText(systemMessage.getContent()).build()); - } - - return parts; - } - else if (message instanceof UserMessage userMessage) { - List parts = new ArrayList<>(); - if (userMessage.getContent() != null) { - parts.add(Part.newBuilder().setText(userMessage.getContent()).build()); - } - - parts.addAll(mediaToParts(userMessage.getMedia())); - - return parts; - } - else if (message instanceof AssistantMessage assistantMessage) { - List parts = new ArrayList<>(); - if (StringUtils.hasText(assistantMessage.getContent())) { - parts.add(Part.newBuilder().setText(assistantMessage.getContent()).build()); - } - if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { - parts.addAll(assistantMessage.getToolCalls() - .stream() - .map(toolCall -> Part.newBuilder() - .setFunctionCall(FunctionCall.newBuilder() - .setName(toolCall.name()) - .setArgs(jsonToStruct(toolCall.arguments())) - .build()) - .build()) - .toList()); - } - return parts; - } - else if (message instanceof ToolResponseMessage toolResponseMessage) { - - return toolResponseMessage.getResponses() - .stream() - .map(response -> Part.newBuilder() - .setFunctionResponse(FunctionResponse.newBuilder() - .setName(response.name()) - .setResponse(jsonToStruct(response.responseData())) - .build()) - .build()) - .toList(); - } - else { - throw new IllegalArgumentException("Gemini doesn't support message type: " + message.getClass()); - } - } - - private static List mediaToParts(Collection media) { - List parts = new ArrayList<>(); - - List mediaParts = media.stream() - .map(mediaData -> PartMaker.fromMimeTypeAndData(mediaData.getMimeType().toString(), mediaData.getData())) - .toList(); - - if (!CollectionUtils.isEmpty(mediaParts)) { - parts.addAll(mediaParts); - } - - return parts; - } - private List getFunctionTools(Set functionNames) { final var tool = Tool.newBuilder(); @@ -583,37 +573,6 @@ public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements return List.of(tool.build()); } - private static String structToJson(Struct struct) { - try { - return JsonFormat.printer().print(struct); - } - catch (Exception e) { - throw new RuntimeException(e); - } - } - - private static Struct jsonToStruct(String json) { - try { - var structBuilder = Struct.newBuilder(); - JsonFormat.parser().ignoringUnknownFields().merge(json, structBuilder); - return structBuilder.build(); - } - catch (Exception e) { - throw new RuntimeException(e); - } - } - - private static Schema jsonToSchema(String json) { - try { - var schemaBuilder = Schema.newBuilder(); - JsonFormat.parser().ignoringUnknownFields().merge(json, schemaBuilder); - return schemaBuilder.build(); - } - catch (Exception e) { - throw new RuntimeException(e); - } - } - /** * Generates the content response based on the provided Gemini request. Package * protected for testing purposes. @@ -651,4 +610,57 @@ public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements this.observationConvention = observationConvention; } + public enum GeminiMessageType { + + USER("user"), + + MODEL("model"); + + public final String value; + + GeminiMessageType(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + + public enum ChatModel implements ChatModelDescription { + + /** + * Deprecated by Goolgle in favor of 1.5 pro and flash models. + */ + GEMINI_PRO_VISION("gemini-pro-vision"), + + GEMINI_PRO("gemini-pro"), + + GEMINI_1_5_PRO("gemini-1.5-pro-001"), + + GEMINI_1_5_FLASH("gemini-1.5-flash-001"); + + public final String value; + + ChatModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + @Override + public String getName() { + return this.value; + } + + } + + @JsonInclude(Include.NON_NULL) + public record GeminiRequest(List contents, GenerativeModel model) { + + } + } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 908891645..574574a75 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini; import java.util.ArrayList; @@ -45,41 +46,43 @@ public class VertexAiGeminiChatOptions implements FunctionCallingOptions, ChatOp // https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerationConfig - public enum TransportType { - - GRPC, REST - - } - - // @formatter:off /** * Optional. Stop sequences. */ private @JsonProperty("stopSequences") List stopSequences; + + // @formatter:off + /** * Optional. Controls the randomness of predictions. */ private @JsonProperty("temperature") Double temperature; + /** * Optional. If specified, nucleus sampling will be used. */ private @JsonProperty("topP") Double topP; + /** * Optional. If specified, top k sampling will be used. */ private @JsonProperty("topK") Float topK; + /** * Optional. The maximum number of tokens to generate. */ private @JsonProperty("candidateCount") Integer candidateCount; + /** * Optional. The maximum number of tokens to generate. */ private @JsonProperty("maxOutputTokens") Integer maxOutputTokens; + /** * Gemini model name. */ private @JsonProperty("modelName") String model; + /** * Optional. Output response mimetype of the generated candidate text. * - text/plain: (default) Text output. @@ -123,103 +126,29 @@ public class VertexAiGeminiChatOptions implements FunctionCallingOptions, ChatOp @JsonIgnore private Map toolContext; - // @formatter:on - public static Builder builder() { return new Builder(); } - public static class Builder { - - private VertexAiGeminiChatOptions options = new VertexAiGeminiChatOptions(); - - public Builder withStopSequences(List stopSequences) { - this.options.setStopSequences(stopSequences); - return this; - } - - public Builder withTemperature(Double temperature) { - this.options.setTemperature(temperature); - return this; - } - - public Builder withTopP(Double topP) { - this.options.setTopP(topP); - return this; - } - - public Builder withTopK(Float topK) { - this.options.setTopK(topK); - return this; - } - - public Builder withCandidateCount(Integer candidateCount) { - this.options.setCandidateCount(candidateCount); - return this; - } - - public Builder withMaxOutputTokens(Integer maxOutputTokens) { - this.options.setMaxOutputTokens(maxOutputTokens); - return this; - } - - public Builder withModel(String modelName) { - this.options.setModel(modelName); - return this; - } - - public Builder withModel(ChatModel model) { - this.options.setModel(model.getValue()); - return this; - } - - public Builder withResponseMimeType(String mimeType) { - Assert.notNull(mimeType, "mimeType must not be null"); - this.options.setResponseMimeType(mimeType); - return this; - } - - public Builder withFunctionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; - return this; - } - - public Builder withFunctions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; - return this; - } - - public Builder withFunction(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); - return this; - } - - public Builder withGoogleSearchRetrieval(boolean googleSearch) { - this.options.googleSearchRetrieval = googleSearch; - return this; - } - - public Builder withProxyToolCalls(boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; - return this; - } - - public Builder withToolContext(Map toolContext) { - if (this.options.toolContext == null) { - this.options.toolContext = toolContext; - } - else { - this.options.toolContext.putAll(toolContext); - } - return this; - } - - public VertexAiGeminiChatOptions build() { - return this.options; - } + // @formatter:on + public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fromOptions) { + VertexAiGeminiChatOptions options = new VertexAiGeminiChatOptions(); + options.setStopSequences(fromOptions.getStopSequences()); + options.setTemperature(fromOptions.getTemperature()); + options.setTopP(fromOptions.getTopP()); + options.setTopK(fromOptions.getTopK()); + options.setCandidateCount(fromOptions.getCandidateCount()); + options.setMaxOutputTokens(fromOptions.getMaxOutputTokens()); + options.setModel(fromOptions.getModel()); + options.setFunctionCallbacks(fromOptions.getFunctionCallbacks()); + options.setResponseMimeType(fromOptions.getResponseMimeType()); + options.setFunctions(fromOptions.getFunctions()); + options.setResponseMimeType(fromOptions.getResponseMimeType()); + options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval()); + options.setProxyToolCalls(fromOptions.getProxyToolCalls()); + options.setToolContext(fromOptions.getToolContext()); + return options; } @Override @@ -364,33 +293,39 @@ public class VertexAiGeminiChatOptions implements FunctionCallingOptions, ChatOp @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof VertexAiGeminiChatOptions that)) + } + if (!(o instanceof VertexAiGeminiChatOptions that)) { return false; - return googleSearchRetrieval == that.googleSearchRetrieval && Objects.equals(stopSequences, that.stopSequences) - && Objects.equals(temperature, that.temperature) && Objects.equals(topP, that.topP) - && Objects.equals(topK, that.topK) && Objects.equals(candidateCount, that.candidateCount) - && Objects.equals(maxOutputTokens, that.maxOutputTokens) && Objects.equals(model, that.model) - && Objects.equals(responseMimeType, that.responseMimeType) - && Objects.equals(functionCallbacks, that.functionCallbacks) - && Objects.equals(functions, that.functions) && Objects.equals(proxyToolCalls, that.proxyToolCalls) - && Objects.equals(toolContext, that.toolContext); + } + return this.googleSearchRetrieval == that.googleSearchRetrieval + && Objects.equals(this.stopSequences, that.stopSequences) + && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP) + && Objects.equals(this.topK, that.topK) && Objects.equals(this.candidateCount, that.candidateCount) + && Objects.equals(this.maxOutputTokens, that.maxOutputTokens) && Objects.equals(this.model, that.model) + && Objects.equals(this.responseMimeType, that.responseMimeType) + && Objects.equals(this.functionCallbacks, that.functionCallbacks) + && Objects.equals(this.functions, that.functions) + && Objects.equals(this.proxyToolCalls, that.proxyToolCalls) + && Objects.equals(this.toolContext, that.toolContext); } @Override public int hashCode() { - return Objects.hash(stopSequences, temperature, topP, topK, candidateCount, maxOutputTokens, model, - responseMimeType, functionCallbacks, functions, googleSearchRetrieval, proxyToolCalls, toolContext); + return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, + this.maxOutputTokens, this.model, this.responseMimeType, this.functionCallbacks, this.functions, + this.googleSearchRetrieval, this.proxyToolCalls, this.toolContext); } @Override public String toString() { - return "VertexAiGeminiChatOptions{" + "stopSequences=" + stopSequences + ", temperature=" + temperature - + ", topP=" + topP + ", topK=" + topK + ", candidateCount=" + candidateCount + ", maxOutputTokens=" - + maxOutputTokens + ", model='" + model + '\'' + ", responseMimeType='" + responseMimeType + '\'' - + ", functionCallbacks=" + functionCallbacks + ", functions=" + functions + ", googleSearchRetrieval=" - + googleSearchRetrieval + '}'; + return "VertexAiGeminiChatOptions{" + "stopSequences=" + this.stopSequences + ", temperature=" + + this.temperature + ", topP=" + this.topP + ", topK=" + this.topK + ", candidateCount=" + + this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\'' + + ", responseMimeType='" + this.responseMimeType + '\'' + ", functionCallbacks=" + + this.functionCallbacks + ", functions=" + this.functions + ", googleSearchRetrieval=" + + this.googleSearchRetrieval + '}'; } @Override @@ -398,23 +333,103 @@ public class VertexAiGeminiChatOptions implements FunctionCallingOptions, ChatOp return fromOptions(this); } - public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fromOptions) { - VertexAiGeminiChatOptions options = new VertexAiGeminiChatOptions(); - options.setStopSequences(fromOptions.getStopSequences()); - options.setTemperature(fromOptions.getTemperature()); - options.setTopP(fromOptions.getTopP()); - options.setTopK(fromOptions.getTopK()); - options.setCandidateCount(fromOptions.getCandidateCount()); - options.setMaxOutputTokens(fromOptions.getMaxOutputTokens()); - options.setModel(fromOptions.getModel()); - options.setFunctionCallbacks(fromOptions.getFunctionCallbacks()); - options.setResponseMimeType(fromOptions.getResponseMimeType()); - options.setFunctions(fromOptions.getFunctions()); - options.setResponseMimeType(fromOptions.getResponseMimeType()); - options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval()); - options.setProxyToolCalls(fromOptions.getProxyToolCalls()); - options.setToolContext(fromOptions.getToolContext()); - return options; + public enum TransportType { + + GRPC, REST + + } + + public static class Builder { + + private VertexAiGeminiChatOptions options = new VertexAiGeminiChatOptions(); + + public Builder withStopSequences(List stopSequences) { + this.options.setStopSequences(stopSequences); + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public Builder withTopP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public Builder withTopK(Float topK) { + this.options.setTopK(topK); + return this; + } + + public Builder withCandidateCount(Integer candidateCount) { + this.options.setCandidateCount(candidateCount); + return this; + } + + public Builder withMaxOutputTokens(Integer maxOutputTokens) { + this.options.setMaxOutputTokens(maxOutputTokens); + return this; + } + + public Builder withModel(String modelName) { + this.options.setModel(modelName); + return this; + } + + public Builder withModel(ChatModel model) { + this.options.setModel(model.getValue()); + return this; + } + + public Builder withResponseMimeType(String mimeType) { + Assert.notNull(mimeType, "mimeType must not be null"); + this.options.setResponseMimeType(mimeType); + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public Builder withGoogleSearchRetrieval(boolean googleSearch) { + this.options.googleSearchRetrieval = googleSearch; + return this; + } + + public Builder withProxyToolCalls(boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public VertexAiGeminiChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java index 0a46b9f2f..fd3d04106 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini.aot; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/common/VertexAiGeminiConstants.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/common/VertexAiGeminiConstants.java index 4369a2e79..2d8b69f98 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/common/VertexAiGeminiConstants.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/common/VertexAiGeminiConstants.java @@ -1,11 +1,11 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/metadata/VertexAiUsage.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/metadata/VertexAiUsage.java index a250b98e0..aeab57df7 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/metadata/VertexAiUsage.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/metadata/VertexAiUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini.metadata; import com.google.cloud.vertexai.api.GenerateContentResponse.UsageMetadata; @@ -23,7 +24,7 @@ import org.springframework.util.Assert; /** * @author Christian Tzolov * @since 0.8.1 - * + * */ public class VertexAiUsage implements Usage { @@ -36,12 +37,12 @@ public class VertexAiUsage implements Usage { @Override public Long getPromptTokens() { - return Long.valueOf(usageMetadata.getPromptTokenCount()); + return Long.valueOf(this.usageMetadata.getPromptTokenCount()); } @Override public Long getGenerationTokens() { - return Long.valueOf(usageMetadata.getCandidatesTokenCount()); + return Long.valueOf(this.usageMetadata.getCandidatesTokenCount()); } } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java index acbe2672f..3925fe125 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,30 +13,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vertexai.gemini; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vertexai.gemini; import java.net.MalformedURLException; import java.net.URL; import java.util.List; +import com.google.cloud.vertexai.VertexAI; +import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.Part; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.ai.model.Media; + import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.GeminiRequest; +import org.springframework.ai.vertexai.gemini.function.MockWeatherService; import org.springframework.util.MimeTypeUtils; -import com.google.cloud.vertexai.VertexAI; -import com.google.cloud.vertexai.api.Content; -import com.google.cloud.vertexai.api.Part; -import org.springframework.ai.vertexai.gemini.function.MockWeatherService; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -50,7 +51,7 @@ public class CreateGeminiRequestTests { @Test public void createRequestWithChatOptions() { - var client = new VertexAiGeminiChatModel(vertexAI, + var client = new VertexAiGeminiChatModel(this.vertexAI, VertexAiGeminiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6).build()); GeminiRequest request = client.createGeminiRequest(new Prompt("Test message content"), null); @@ -81,7 +82,7 @@ public class CreateGeminiRequestTests { var userMessage = new UserMessage("User Message Text", List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("http://example.com")))); - var client = new VertexAiGeminiChatModel(vertexAI, + var client = new VertexAiGeminiChatModel(this.vertexAI, VertexAiGeminiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6).build()); GeminiRequest request = client.createGeminiRequest(new Prompt(List.of(systemMessage, userMessage)), null); @@ -110,7 +111,7 @@ public class CreateGeminiRequestTests { final String TOOL_FUNCTION_NAME = "CurrentWeather"; - var client = new VertexAiGeminiChatModel(vertexAI, + var client = new VertexAiGeminiChatModel(this.vertexAI, VertexAiGeminiChatOptions.builder().withModel("DEFAULT_MODEL").build()); var request = client.createGeminiRequest(new Prompt("Test message content", @@ -141,7 +142,7 @@ public class CreateGeminiRequestTests { final String TOOL_FUNCTION_NAME = "CurrentWeather"; - var client = new VertexAiGeminiChatModel(vertexAI, + var client = new VertexAiGeminiChatModel(this.vertexAI, VertexAiGeminiChatOptions.builder() .withModel("DEFAULT_MODEL") .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) @@ -198,7 +199,7 @@ public class CreateGeminiRequestTests { @Test public void createRequestWithGenerationConfigOptions() { - var client = new VertexAiGeminiChatModel(vertexAI, + var client = new VertexAiGeminiChatModel(this.vertexAI, VertexAiGeminiChatOptions.builder() .withModel("DEFAULT_MODEL") .withTemperature(66.6) diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java index c7db75aab..9ab82aa64 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,16 +16,17 @@ package org.springframework.ai.vertexai.gemini; +import java.io.IOException; +import java.util.List; + import com.google.cloud.vertexai.VertexAI; import com.google.cloud.vertexai.api.GenerateContentResponse; import com.google.cloud.vertexai.generativeai.GenerativeModel; + import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.retry.support.RetryTemplate; -import java.io.IOException; -import java.util.List; - /** * @author Mark Pollack */ @@ -41,9 +42,9 @@ public class TestVertexAiGeminiChatModel extends VertexAiGeminiChatModel { @Override GenerateContentResponse getContentResponse(GeminiRequest request) { - if (mockGenerativeModel != null) { + if (this.mockGenerativeModel != null) { try { - return mockGenerativeModel.generateContent(request.contents()); + return this.mockGenerativeModel.generateContent(request.contents()); } catch (IOException e) { // Should not be thrown by testing class diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java index 093dd3ba9..e34963c59 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,14 +16,17 @@ package org.springframework.ai.vertexai.gemini; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.List; import java.util.stream.Collectors; +import com.google.cloud.vertexai.Transport; +import com.google.cloud.vertexai.VertexAI; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; @@ -38,11 +41,7 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import com.google.cloud.vertexai.Transport; -import com.google.cloud.vertexai.VertexAI; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Soby Chacko @@ -60,7 +59,7 @@ public class VertexAiChatModelObservationIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -76,7 +75,7 @@ public class VertexAiChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -98,7 +97,7 @@ public class VertexAiChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponse = chatModel.stream(prompt); + Flux chatResponse = this.chatModel.stream(prompt); List responses = chatResponse.collectList().block(); assertThat(responses).isNotEmpty(); assertThat(responses).hasSizeGreaterThan(1); @@ -118,7 +117,7 @@ public class VertexAiChatModelObservationIT { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java index b11c8f1ec..4d6a06444 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini; import java.io.IOException; @@ -28,18 +29,18 @@ import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.model.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.Media; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -66,19 +67,19 @@ class VertexAiGeminiChatModelIT { @Test void roleTest() { Prompt prompt = createPrompt(VertexAiGeminiChatOptions.builder().build()); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); } @Test void testMessageHistory() { Prompt prompt = createPrompt(VertexAiGeminiChatOptions.builder().build()); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); var promptWithMessageHistory = new Prompt(List.of(new UserMessage("Dummy"), prompt.getInstructions().get(1), response.getResult().getOutput(), new UserMessage("Repeat the last assistant message."))); - response = chatModel.call(promptWithMessageHistory); + response = this.chatModel.call(promptWithMessageHistory); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); } @@ -86,7 +87,7 @@ class VertexAiGeminiChatModelIT { @Test void googleSearchTool() { Prompt prompt = createPrompt(VertexAiGeminiChatOptions.builder().withGoogleSearchRetrieval(true).build()); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); } @@ -96,7 +97,7 @@ class VertexAiGeminiChatModelIT { String name = "Bob"; String voice = "pirate"; UserMessage userMessage = new UserMessage(request); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); Prompt prompt = new Prompt(List.of(userMessage, systemMessage), chatOptions); return prompt; @@ -133,16 +134,13 @@ class VertexAiGeminiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -156,7 +154,7 @@ class VertexAiGeminiChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConvert.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); @@ -166,7 +164,8 @@ class VertexAiGeminiChatModelIT { @Test void textStream() { - String generationTextFromStream = chatModel.stream(new Prompt("Explain Bulgaria? Answer in 10 paragraphs.")) + String generationTextFromStream = this.chatModel + .stream(new Prompt("Explain Bulgaria? Answer in 10 paragraphs.")) .collectList() .block() .stream() @@ -194,7 +193,7 @@ class VertexAiGeminiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -218,7 +217,7 @@ class VertexAiGeminiChatModelIT { var userMessage = new UserMessage("Explain what do you see o this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, data))); - var response = chatModel.call(new Prompt(List.of(userMessage))); + var response = this.chatModel.call(new Prompt(List.of(userMessage))); // Response should contain something like: // I see a bunch of bananas in a golden basket. The bananas are ripe and yellow. @@ -247,6 +246,10 @@ class VertexAiGeminiChatModelIT { // https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/intro_multimodal_use_cases.ipynb } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java index 387d25501..bddb9328a 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,10 +16,9 @@ package org.springframework.ai.vertexai.gemini; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.*; +import java.io.IOException; +import java.util.Collections; +import java.util.List; import com.google.cloud.vertexai.VertexAI; import com.google.cloud.vertexai.api.Candidate; @@ -32,6 +31,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.retry.RetryUtils; @@ -41,11 +41,10 @@ import org.springframework.retry.RetryContext; import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; -import java.io.IOException; -import java.util.Collections; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.when; /** * @author Mark Pollack @@ -54,25 +53,6 @@ import static org.assertj.core.api.Assertions.assertThat; @ExtendWith(MockitoExtension.class) public class VertexAiGeminiRetryTests { - private static class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private RetryTemplate retryTemplate; @@ -87,19 +67,19 @@ public class VertexAiGeminiRetryTests { @BeforeEach public void setUp() { - retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); + this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.registerListener(this.retryListener); - chatModel = new TestVertexAiGeminiChatModel(vertexAI, + this.chatModel = new TestVertexAiGeminiChatModel(this.vertexAI, VertexAiGeminiChatOptions.builder() .withTemperature(0.7) .withTopP(1.0) .withModel(VertexAiGeminiChatModel.ChatModel.GEMINI_PRO.getValue()) .build(), - null, Collections.emptyList(), retryTemplate); + null, Collections.emptyList(), this.retryTemplate); - chatModel.setMockGenerativeModel(mockGenerativeModel); + this.chatModel.setMockGenerativeModel(this.mockGenerativeModel); } @Test @@ -111,29 +91,48 @@ public class VertexAiGeminiRetryTests { .build()) .build(); - when(mockGenerativeModel.generateContent(any(List.class))) + when(this.mockGenerativeModel.generateContent(any(List.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(mockedResponse); // Call the chat model - ChatResponse result = chatModel.call(new Prompt("test prompt")); + ChatResponse result = this.chatModel.call(new Prompt("test prompt")); // Assertions assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getContent()).isEqualTo("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void vertexAiGeminiChatNonTransientError() throws Exception { // Set up the mock GenerativeModel to throw a non-transient RuntimeException - when(mockGenerativeModel.generateContent(any(List.class))) + when(this.mockGenerativeModel.generateContent(any(List.class))) .thenThrow(new RuntimeException("Non Transient Error")); // Assert that a RuntimeException is thrown when calling the chat model - assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("test prompt"))); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("test prompt"))); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + } } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHintsTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHintsTests.java index a4aaf3988..88774e9e3 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHintsTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini.aot; import java.util.Set; diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java index ff62411a9..a7f7521df 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.gemini.function; import java.util.function.Function; @@ -32,14 +33,22 @@ public class MockWeatherService implements Function response = chatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); String responseString = response.collectList() .block() @@ -203,16 +203,18 @@ public class VertexAiGeminiChatModelFunctionCallingIT { .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", responseString); + this.logger.info("Response: {}", responseString); assertThat(responseString).contains("30", "10", "15"); } public record PaymentInfoRequest(String id) { + } public record TransactionStatus(String status) { + } public static class PaymentStatus implements Function { diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java index 7f6ec1d3a..4d674c8bd 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,23 +16,25 @@ package org.springframework.ai.vertexai.gemini.function; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.List; import java.util.Map; import java.util.function.Function; import java.util.stream.Collectors; +import com.google.cloud.vertexai.Transport; +import com.google.cloud.vertexai.VertexAI; import org.junit.jupiter.api.RepeatedTest; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.model.function.FunctionCallbackContext.SchemaType; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; @@ -44,10 +46,7 @@ import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Description; -import com.google.cloud.vertexai.Transport; -import com.google.cloud.vertexai.VertexAI; - -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -59,51 +58,12 @@ public class VertexAiGeminiPaymentTransactionIT { private final static Logger logger = LoggerFactory.getLogger(VertexAiGeminiPaymentTransactionIT.class); + private static final Map DATASET = Map.of(new Transaction("001"), new Status("pending"), + new Transaction("002"), new Status("approved"), new Transaction("003"), new Status("rejected")); + @Autowired ChatClient chatClient; - record TransactionStatusResponse(String id, String status) { - } - - private static class LoggingAdvisor implements CallAroundAdvisor { - - private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); - - @Override - public String getName() { - return this.getClass().getSimpleName(); - } - - @Override - public int getOrder() { - return 0; - } - - @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - var response = chain.nextAroundCall(before(advisedRequest)); - observeAfter(response); - return response; - } - - private AdvisedRequest before(AdvisedRequest request) { - logger.info("System text: \n" + request.systemText()); - logger.info("System params: " + request.systemParams()); - logger.info("User text: \n" + request.userText()); - logger.info("User params:" + request.userParams()); - logger.info("Function names: " + request.functionNames()); - - logger.info("Options: " + request.chatOptions().toString()); - - return request; - } - - private void observeAfter(AdvisedResponse advisedResponse) { - logger.info("Response: " + advisedResponse.response()); - } - - } - @Test public void paymentStatuses() { // @formatter:off @@ -149,6 +109,49 @@ public class VertexAiGeminiPaymentTransactionIT { } } + record TransactionStatusResponse(String id, String status) { + + } + + private static class LoggingAdvisor implements CallAroundAdvisor { + + private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); + + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public int getOrder() { + return 0; + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + var response = chain.nextAroundCall(before(advisedRequest)); + observeAfter(response); + return response; + } + + private AdvisedRequest before(AdvisedRequest request) { + this.logger.info("System text: \n" + request.systemText()); + this.logger.info("System params: " + request.systemParams()); + this.logger.info("User text: \n" + request.userText()); + this.logger.info("User params:" + request.userParams()); + this.logger.info("Function names: " + request.functionNames()); + + this.logger.info("Options: " + request.chatOptions().toString()); + + return request; + } + + private void observeAfter(AdvisedResponse advisedResponse) { + this.logger.info("Response: " + advisedResponse.response()); + } + + } + record Transaction(String id) { } @@ -161,9 +164,6 @@ public class VertexAiGeminiPaymentTransactionIT { record Statuses(List statuses) { } - private static final Map DATASET = Map.of(new Transaction("001"), new Status("pending"), - new Transaction("002"), new Status("approved"), new Transaction("003"), new Status("rejected")); - @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-vertex-ai-palm2/pom.xml b/models/spring-ai-vertex-ai-palm2/pom.xml index 2c5123b3c..07455d51d 100644 --- a/models/spring-ai-vertex-ai-palm2/pom.xml +++ b/models/spring-ai-vertex-ai-palm2/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatModel.java b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatModel.java index de8d73263..b68e054a4 100644 --- a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatModel.java +++ b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,22 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2; import java.util.List; import java.util.stream.Collectors; +import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.GenerateMessageRequest; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.GenerateMessageResponse; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.MessagePrompt; -import org.springframework.ai.chat.messages.MessageType; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; diff --git a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatOptions.java b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatOptions.java index 77b968f69..34a63ccb7 100644 --- a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatOptions.java +++ b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -22,8 +25,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.chat.prompt.ChatOptions; -import java.util.List; - /** * @author Christian Tzolov * @author Thomas Vitale @@ -66,34 +67,13 @@ public class VertexAiPaLm2ChatOptions implements ChatOptions { return new Builder(); } - public static class Builder { - - private VertexAiPaLm2ChatOptions options = new VertexAiPaLm2ChatOptions(); - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withCandidateCount(Integer candidateCount) { - this.options.candidateCount = candidateCount; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public Builder withTopK(Integer topK) { - this.options.topK = topK; - return this; - } - - public VertexAiPaLm2ChatOptions build() { - return this.options; - } - + public static VertexAiPaLm2ChatOptions fromOptions(VertexAiPaLm2ChatOptions fromOptions) { + return VertexAiPaLm2ChatOptions.builder() + .withTemperature(fromOptions.getTemperature()) + .withCandidateCount(fromOptions.getCandidateCount()) + .withTopP(fromOptions.getTopP()) + .withTopK(fromOptions.getTopK()) + .build(); } @Override @@ -166,13 +146,34 @@ public class VertexAiPaLm2ChatOptions implements ChatOptions { return fromOptions(this); } - public static VertexAiPaLm2ChatOptions fromOptions(VertexAiPaLm2ChatOptions fromOptions) { - return VertexAiPaLm2ChatOptions.builder() - .withTemperature(fromOptions.getTemperature()) - .withCandidateCount(fromOptions.getCandidateCount()) - .withTopP(fromOptions.getTopP()) - .withTopK(fromOptions.getTopK()) - .build(); + public static class Builder { + + private VertexAiPaLm2ChatOptions options = new VertexAiPaLm2ChatOptions(); + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withCandidateCount(Integer candidateCount) { + this.options.candidateCount = candidateCount; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder withTopK(Integer topK) { + this.options.topK = topK; + return this; + } + + public VertexAiPaLm2ChatOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModel.java b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModel.java index 3fcdb9357..3ece4c0bc 100644 --- a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModel.java +++ b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2; import java.util.List; diff --git a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHints.java b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHints.java index 65a8952c6..8e9ea6707 100644 --- a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHints.java +++ b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2.aot; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api; diff --git a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2Api.java b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2Api.java index 2ad7c827b..e1e46326d 100644 --- a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2Api.java +++ b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2Api.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2.api; import java.io.IOException; @@ -207,10 +208,6 @@ public class VertexAiPaLm2Api { return response != null ? response.embedding() : null; } - @JsonInclude(Include.NON_NULL) - record BatchEmbeddingResponse(List embeddings) { - } - /** * Generates a response from the model given an input. * @param texts List of texts to embed. @@ -294,6 +291,10 @@ public class VertexAiPaLm2Api { .body(Model.class); } + @JsonInclude(Include.NON_NULL) + record BatchEmbeddingResponse(List embeddings) { + } + /** * API error response. * @@ -375,12 +376,12 @@ public class VertexAiPaLm2Api { @Override public final int hashCode() { - return Arrays.hashCode(value); + return Arrays.hashCode(this.value); } @Override public final boolean equals(Object arg0) { - return Arrays.equals(value,((Embedding) arg0).value); + return Arrays.equals(this.value,((Embedding) arg0).value); } } diff --git a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatGenerationClientIT.java b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatGenerationClientIT.java index 3a98497cb..21d07a608 100644 --- a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatGenerationClientIT.java +++ b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatGenerationClientIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2; import java.util.Arrays; @@ -22,10 +23,10 @@ import java.util.Map; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -59,10 +60,10 @@ class VertexAiPaLm2ChatGenerationClientIT { String name = "Bob"; String voice = "pirate"; UserMessage userMessage = new UserMessage(request); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Bartholomew"); } @@ -98,16 +99,13 @@ class VertexAiPaLm2ChatGenerationClientIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - // @Test void beanOutputConverterRecords() { @@ -120,13 +118,17 @@ class VertexAiPaLm2ChatGenerationClientIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } + record ActorsFilmsRecord(String actor, List movies) { + + } + @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatRequestTests.java b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatRequestTests.java index 9e2d6726c..d2dc28024 100644 --- a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatRequestTests.java +++ b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2; import org.junit.jupiter.api.Test; @@ -34,7 +35,7 @@ public class VertexAiPaLm2ChatRequestTests { @Test public void createRequestWithDefaultOptions() { - var request = chatModel.createRequest(new Prompt("Test message content")); + var request = this.chatModel.createRequest(new Prompt("Test message content")); assertThat(request.prompt().messages()).hasSize(1); @@ -55,7 +56,7 @@ public class VertexAiPaLm2ChatRequestTests { // .withCandidateCount(2) .build(); - var request = chatModel.createRequest(new Prompt("Test message content", promptOptions)); + var request = this.chatModel.createRequest(new Prompt("Test message content", promptOptions)); assertThat(request.prompt().messages()).hasSize(1); @@ -75,7 +76,7 @@ public class VertexAiPaLm2ChatRequestTests { .withTopP(0.6) .build(); - var request = chatModel.createRequest(new Prompt("Test message content", portablePromptOptions)); + var request = this.chatModel.createRequest(new Prompt("Test message content", portablePromptOptions)); assertThat(request.prompt().messages()).hasSize(1); diff --git a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModelIT.java b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModelIT.java index 2e05dbfd1..221964fd1 100644 --- a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModelIT.java +++ b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2EmbeddingModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2; import java.util.List; @@ -38,17 +39,17 @@ class VertexAiPaLm2EmbeddingModelIT { @Test void simpleEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); - assertThat(embeddingModel.dimensions()).isEqualTo(768); + assertThat(this.embeddingModel.dimensions()).isEqualTo(768); } @Test void batchEmbedding() { - assertThat(embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); assertThat(embeddingResponse.getResults()).hasSize(2); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); @@ -56,7 +57,7 @@ class VertexAiPaLm2EmbeddingModelIT { assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); - assertThat(embeddingModel.dimensions()).isEqualTo(768); + assertThat(this.embeddingModel.dimensions()).isEqualTo(768); } @SpringBootConfiguration diff --git a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHintsTests.java b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHintsTests.java index 5ed21d464..ae57c09a3 100644 --- a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHintsTests.java +++ b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2.aot; +import java.util.Set; + import org.junit.jupiter.api.Test; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; -import java.util.Set; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; diff --git a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiIT.java b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiIT.java index d820d4952..40f854df0 100644 --- a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiIT.java +++ b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2.api; import java.util.List; @@ -47,7 +48,7 @@ public class VertexAiPaLm2ApiIT { GenerateMessageRequest request = new GenerateMessageRequest(prompt); - GenerateMessageResponse response = vertexAiPaLm2Api.generateMessage(request); + GenerateMessageResponse response = this.vertexAiPaLm2Api.generateMessage(request); assertThat(response).isNotNull(); @@ -66,7 +67,7 @@ public class VertexAiPaLm2ApiIT { var text = "Hello, how are you?"; - Embedding response = vertexAiPaLm2Api.embedText(text); + Embedding response = this.vertexAiPaLm2Api.embedText(text); assertThat(response).isNotNull(); assertThat(response.value()).hasSize(768); @@ -77,7 +78,7 @@ public class VertexAiPaLm2ApiIT { var text = List.of("Hello, how are you?", "I am fine, thank you!"); - List response = vertexAiPaLm2Api.batchEmbedText(text); + List response = this.vertexAiPaLm2Api.batchEmbedText(text); assertThat(response).isNotNull(); assertThat(response).hasSize(2); @@ -91,7 +92,7 @@ public class VertexAiPaLm2ApiIT { var text = "Hello, how are you?"; var prompt = new MessagePrompt(List.of(new VertexAiPaLm2Api.Message("0", text))); - int response = vertexAiPaLm2Api.countMessageTokens(prompt); + int response = this.vertexAiPaLm2Api.countMessageTokens(prompt); assertThat(response).isEqualTo(17); } @@ -99,14 +100,14 @@ public class VertexAiPaLm2ApiIT { @Test public void listModels() { - List response = vertexAiPaLm2Api.listModels(); + List response = this.vertexAiPaLm2Api.listModels(); assertThat(response).isNotNull(); assertThat(response).hasSizeGreaterThan(0); assertThat(response).contains("models/chat-bison-001", "models/text-bison-001", "models/embedding-gecko-001"); System.out.println(" - " + response.stream() - .map(vertexAiPaLm2Api::getModel) + .map(this.vertexAiPaLm2Api::getModel) .map(VertexAiPaLm2Api.Model::toString) .collect(Collectors.joining("\n - "))); } @@ -114,7 +115,7 @@ public class VertexAiPaLm2ApiIT { @Test public void getModel() { - VertexAiPaLm2Api.Model model = vertexAiPaLm2Api.getModel("models/chat-bison-001"); + VertexAiPaLm2Api.Model model = this.vertexAiPaLm2Api.getModel("models/chat-bison-001"); System.out.println(model); assertThat(model).isNotNull(); diff --git a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiTests.java b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiTests.java index 2f58d9dd6..77bfe39a4 100644 --- a/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiTests.java +++ b/models/spring-ai-vertex-ai-palm2/src/test/java/org/springframework/ai/vertexai/palm2/api/VertexAiPaLm2ApiTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vertexai.palm2.api; import java.util.List; @@ -26,8 +27,8 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.Embedding; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.GenerateMessageRequest; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.GenerateMessageResponse; -import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.MessagePrompt; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.GenerateMessageResponse.ContentFilter.BlockedReason; +import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.MessagePrompt; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.autoconfigure.web.client.RestClientTest; @@ -62,7 +63,7 @@ public class VertexAiPaLm2ApiTests { @AfterEach void resetMockServer() { - server.reset(); + this.server.reset(); } @Test @@ -76,18 +77,19 @@ public class VertexAiPaLm2ApiTests { List.of(new VertexAiPaLm2Api.Message("0", "I'm fine, thank you.")), List.of(new VertexAiPaLm2Api.GenerateMessageResponse.ContentFilter(BlockedReason.SAFETY, "reason"))); - server + this.server .expect(requestToUriTemplate("/models/{generative}:generateMessage?key={apiKey}", VertexAiPaLm2Api.DEFAULT_GENERATE_MODEL, TEST_API_KEY)) .andExpect(method(HttpMethod.POST)) - .andExpect(content().json(objectMapper.writeValueAsString(request))) - .andRespond(withSuccess(objectMapper.writeValueAsString(expectedResponse), MediaType.APPLICATION_JSON)); + .andExpect(content().json(this.objectMapper.writeValueAsString(request))) + .andRespond( + withSuccess(this.objectMapper.writeValueAsString(expectedResponse), MediaType.APPLICATION_JSON)); - GenerateMessageResponse response = client.generateMessage(request); + GenerateMessageResponse response = this.client.generateMessage(request); assertThat(response).isEqualTo(expectedResponse); - server.verify(); + this.server.verify(); } @Test @@ -97,19 +99,19 @@ public class VertexAiPaLm2ApiTests { Embedding expectedEmbedding = new Embedding(new float[] { 0.1f, 0.2f, 0.3f }); - server + this.server .expect(requestToUriTemplate("/models/{generative}:embedText?key={apiKey}", VertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY)) .andExpect(method(HttpMethod.POST)) - .andExpect(content().json(objectMapper.writeValueAsString(Map.of("text", text)))) - .andRespond(withSuccess(objectMapper.writeValueAsString(Map.of("embedding", expectedEmbedding)), + .andExpect(content().json(this.objectMapper.writeValueAsString(Map.of("text", text)))) + .andRespond(withSuccess(this.objectMapper.writeValueAsString(Map.of("embedding", expectedEmbedding)), MediaType.APPLICATION_JSON)); - Embedding embedding = client.embedText(text); + Embedding embedding = this.client.embedText(text); assertThat(embedding).isEqualTo(expectedEmbedding); - server.verify(); + this.server.verify(); } @Test @@ -120,19 +122,19 @@ public class VertexAiPaLm2ApiTests { List expectedEmbeddings = List.of(new Embedding(new float[] { 0.1f, 0.2f, 0.3f }), new Embedding(new float[] { 0.4f, 0.5f, 0.6f })); - server + this.server .expect(requestToUriTemplate("/models/{generative}:batchEmbedText?key={apiKey}", VertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY)) .andExpect(method(HttpMethod.POST)) - .andExpect(content().json(objectMapper.writeValueAsString(Map.of("texts", texts)))) - .andRespond(withSuccess(objectMapper.writeValueAsString(Map.of("embeddings", expectedEmbeddings)), + .andExpect(content().json(this.objectMapper.writeValueAsString(Map.of("texts", texts)))) + .andRespond(withSuccess(this.objectMapper.writeValueAsString(Map.of("embeddings", expectedEmbeddings)), MediaType.APPLICATION_JSON)); - List embeddings = client.batchEmbedText(texts); + List embeddings = this.client.batchEmbedText(texts); assertThat(embeddings).isEqualTo(expectedEmbeddings); - server.verify(); + this.server.verify(); } @SpringBootConfiguration diff --git a/models/spring-ai-watsonx-ai/pom.xml b/models/spring-ai-watsonx-ai/pom.xml index e5fc7df00..3fc195eb4 100644 --- a/models/spring-ai-watsonx-ai/pom.xml +++ b/models/spring-ai-watsonx-ai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java index 218ce4697..8b4e78b34 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx; import java.util.List; import java.util.Map; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.model.ChatModel; import reactor.core.publisher.Flux; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; @@ -145,4 +146,4 @@ public class WatsonxAiChatModel implements ChatModel, StreamingChatModel { return WatsonxAiChatOptions.fromOptions(this.defaultOptions); } -} \ No newline at end of file +} diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java index e32ba2d1b..9a113da57 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx; import java.util.HashMap; @@ -20,13 +21,14 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonAnyGetter; import com.fasterxml.jackson.annotation.JsonAnySetter; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; + import org.springframework.ai.chat.prompt.ChatOptions; /** @@ -44,6 +46,9 @@ import org.springframework.ai.chat.prompt.ChatOptions; public class WatsonxAiChatOptions implements ChatOptions { + @JsonIgnore + private final ObjectMapper mapper = new ObjectMapper(); + /** * The temperature of the model. Increasing the temperature will * make the model answer more creatively. (Default: 0.7) @@ -122,12 +127,41 @@ public class WatsonxAiChatOptions implements ChatOptions { @JsonProperty("additional") private Map additional = new HashMap<>(); - @JsonIgnore - private final ObjectMapper mapper = new ObjectMapper(); + public static Builder builder() { + return new Builder(); + } + + /** + * Filter out the non-supported fields from the options. + * @param options The options to filter. + * @return The filtered options. + */ + public static Map filterNonSupportedFields(Map options) { + return options.entrySet().stream() + .filter(e -> !e.getKey().equals("model")) + .filter(e -> e.getValue() != null) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + public static WatsonxAiChatOptions fromOptions(WatsonxAiChatOptions fromOptions) { + return WatsonxAiChatOptions.builder() + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTopK(fromOptions.getTopK()) + .withDecodingMethod(fromOptions.getDecodingMethod()) + .withMaxNewTokens(fromOptions.getMaxNewTokens()) + .withMinNewTokens(fromOptions.getMinNewTokens()) + .withStopSequences(fromOptions.getStopSequences()) + .withRepetitionPenalty(fromOptions.getRepetitionPenalty()) + .withRandomSeed(fromOptions.getRandomSeed()) + .withModel(fromOptions.getModel()) + .withAdditionalProperties(fromOptions.getAdditionalProperties()) + .build(); + } @Override public Double getTemperature() { - return temperature; + return this.temperature; } public void setTemperature(Double temperature) { @@ -136,7 +170,7 @@ public class WatsonxAiChatOptions implements ChatOptions { @Override public Double getTopP() { - return topP; + return this.topP; } public void setTopP(Double topP) { @@ -145,7 +179,7 @@ public class WatsonxAiChatOptions implements ChatOptions { @Override public Integer getTopK() { - return topK; + return this.topK; } public void setTopK(Integer topK) { @@ -153,7 +187,7 @@ public class WatsonxAiChatOptions implements ChatOptions { } public String getDecodingMethod() { - return decodingMethod; + return this.decodingMethod; } public void setDecodingMethod(String decodingMethod) { @@ -172,7 +206,7 @@ public class WatsonxAiChatOptions implements ChatOptions { } public Integer getMaxNewTokens() { - return maxNewTokens; + return this.maxNewTokens; } public void setMaxNewTokens(Integer maxNewTokens) { @@ -180,7 +214,7 @@ public class WatsonxAiChatOptions implements ChatOptions { } public Integer getMinNewTokens() { - return minNewTokens; + return this.minNewTokens; } public void setMinNewTokens(Integer minNewTokens) { @@ -189,7 +223,7 @@ public class WatsonxAiChatOptions implements ChatOptions { @Override public List getStopSequences() { - return stopSequences; + return this.stopSequences; } public void setStopSequences(List stopSequences) { @@ -208,7 +242,7 @@ public class WatsonxAiChatOptions implements ChatOptions { } public Double getRepetitionPenalty() { - return repetitionPenalty; + return this.repetitionPenalty; } public void setRepetitionPenalty(Double repetitionPenalty) { @@ -216,7 +250,7 @@ public class WatsonxAiChatOptions implements ChatOptions { } public Integer getRandomSeed() { - return randomSeed; + return this.randomSeed; } public void setRandomSeed(Integer randomSeed) { @@ -225,7 +259,7 @@ public class WatsonxAiChatOptions implements ChatOptions { @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -234,7 +268,7 @@ public class WatsonxAiChatOptions implements ChatOptions { @JsonAnyGetter public Map getAdditionalProperties() { - return additional.entrySet().stream() + return this.additional.entrySet().stream() .collect(Collectors.toMap( entry -> toSnakeCase(entry.getKey()), Map.Entry::getValue @@ -243,7 +277,7 @@ public class WatsonxAiChatOptions implements ChatOptions { @JsonAnySetter public void addAdditionalProperty(String key, Object value) { - additional.put(key, value); + this.additional.put(key, value); } @Override @@ -252,9 +286,31 @@ public class WatsonxAiChatOptions implements ChatOptions { return null; } - public static Builder builder() { - return new Builder(); - } + /** + * Convert the {@link WatsonxAiChatOptions} object to a {@link Map} of key/value pairs. + * @return The {@link Map} of key/value pairs. + */ + public Map toMap() { + try { + var json = this.mapper.writeValueAsString(this); + var map = this.mapper.readValue(json, new TypeReference>() {}); + map.remove("additional"); + + return map; + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private String toSnakeCase(String input) { + return input != null ? input.replaceAll("([a-z])([A-Z]+)", "$1_$2").toLowerCase() : null; + } + + @Override + public WatsonxAiChatOptions copy() { + return fromOptions(this); + } public static class Builder { @@ -325,59 +381,5 @@ public class WatsonxAiChatOptions implements ChatOptions { } } - /** - * Convert the {@link WatsonxAiChatOptions} object to a {@link Map} of key/value pairs. - * @return The {@link Map} of key/value pairs. - */ - public Map toMap() { - try { - var json = mapper.writeValueAsString(this); - var map = mapper.readValue(json, new TypeReference>() {}); - map.remove("additional"); - - return map; - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - /** - * Filter out the non-supported fields from the options. - * @param options The options to filter. - * @return The filtered options. - */ - public static Map filterNonSupportedFields(Map options) { - return options.entrySet().stream() - .filter(e -> !e.getKey().equals("model")) - .filter(e -> e.getValue() != null) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - } - - private String toSnakeCase(String input) { - return input != null ? input.replaceAll("([a-z])([A-Z]+)", "$1_$2").toLowerCase() : null; - } - - @Override - public WatsonxAiChatOptions copy() { - return fromOptions(this); - } - - public static WatsonxAiChatOptions fromOptions(WatsonxAiChatOptions fromOptions) { - return WatsonxAiChatOptions.builder() - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withTopK(fromOptions.getTopK()) - .withDecodingMethod(fromOptions.getDecodingMethod()) - .withMaxNewTokens(fromOptions.getMaxNewTokens()) - .withMinNewTokens(fromOptions.getMinNewTokens()) - .withStopSequences(fromOptions.getStopSequences()) - .withRepetitionPenalty(fromOptions.getRepetitionPenalty()) - .withRandomSeed(fromOptions.getRandomSeed()) - .withModel(fromOptions.getModel()) - .withAdditionalProperties(fromOptions.getAdditionalProperties()) - .build(); - } - } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModel.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModel.java index 5b3e03ea1..18e3ae361 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModel.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModel.java @@ -1,18 +1,40 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.watsonx; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.*; +import org.springframework.ai.embedding.AbstractEmbeddingModel; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.watsonx.api.WatsonxAiApi; import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingRequest; import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingResponse; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; - /** * {@link EmbeddingModel} implementation for {@literal Watsonx.ai}. *

diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingOptions.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingOptions.java index 9db6b6dd5..ab1622fea 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingOptions.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingOptions.java @@ -1,8 +1,25 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.watsonx; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.embedding.EmbeddingOptions; /** @@ -22,25 +39,6 @@ public class WatsonxAiEmbeddingOptions implements EmbeddingOptions { @JsonProperty("model_id") private String model; - public WatsonxAiEmbeddingOptions withModel(String model) { - this.model = model; - return this; - } - - public String getModel() { - return model; - } - - public void setModel(String model) { - this.model = model; - } - - @Override - @JsonIgnore - public Integer getDimensions() { - return null; - } - /** * Helper factory method to create a new {@link WatsonxAiEmbeddingOptions} instance. * @return A new {@link WatsonxAiEmbeddingOptions} instance. @@ -53,4 +51,23 @@ public class WatsonxAiEmbeddingOptions implements EmbeddingOptions { return new WatsonxAiEmbeddingOptions().withModel(fromOptions.getModel()); } + public WatsonxAiEmbeddingOptions withModel(String model) { + this.model = model; + return this; + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + @JsonIgnore + public Integer getDimensions() { + return null; + } + } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java index c76470a7a..b78266e7a 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx.aot; import org.springframework.ai.watsonx.WatsonxAiChatOptions; diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java index 2de2f36fd..7953f8c56 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx.api; import java.util.List; @@ -23,14 +24,14 @@ import com.ibm.cloud.sdk.core.security.IamAuthenticator; import com.ibm.cloud.sdk.core.security.IamToken; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.retry.annotation.Backoff; -import org.springframework.retry.annotation.Retryable; import reactor.core.publisher.Flux; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; +import org.springframework.retry.annotation.Backoff; +import org.springframework.retry.annotation.Retryable; import org.springframework.util.Assert; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; @@ -45,8 +46,10 @@ import org.springframework.web.reactive.function.client.WebClient; // @formatter:off public class WatsonxAiApi { - private static final Log logger = LogFactory.getLog(WatsonxAiApi.class); public static final String WATSONX_REQUEST_CANNOT_BE_NULL = "Watsonx Request cannot be null"; + + private static final Log logger = LogFactory.getLog(WatsonxAiApi.class); + private final RestClient restClient; private final WebClient webClient; private final IamAuthenticator iamAuthenticator; @@ -108,7 +111,7 @@ public class WatsonxAiApi { return this.restClient.post() .uri(this.textEndpoint) .header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken()) - .body(watsonxAiChatRequest.withProjectId(projectId)) + .body(watsonxAiChatRequest.withProjectId(this.projectId)) .retrieve() .toEntity(WatsonxAiChatResponse.class); } @@ -146,7 +149,7 @@ public class WatsonxAiApi { return this.restClient.post() .uri(this.embeddingEndpoint) .header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken()) - .body(request.withProjectId(projectId)) + .body(request.withProjectId(this.projectId)) .retrieve() .toEntity(WatsonxAiEmbeddingResponse.class); } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java index c228372cb..817e9802f 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx.api; import java.util.Map; @@ -49,19 +50,18 @@ public class WatsonxAiChatRequest { this.projectId = projectId; } + public static Builder builder(String input) { return new Builder(input); } + public WatsonxAiChatRequest withProjectId(String projectId) { this.projectId = projectId; return this; } - public String getInput() { return input; } + public String getInput() { return this.input; } - public Map getParameters() { return parameters; } + public Map getParameters() { return this.parameters; } - public String getModelId() { return modelId; } - - - public static Builder builder(String input) { return new Builder(input); } + public String getModelId() { return this.modelId; } public static class Builder { public static final String MODEL_PARAMETER_IS_REQUIRED = "Model parameter is required"; @@ -81,7 +81,7 @@ public class WatsonxAiChatRequest { } public WatsonxAiChatRequest build() { - return new WatsonxAiChatRequest(input, parameters, model, ""); + return new WatsonxAiChatRequest(this.input, this.parameters, this.model, ""); } } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResponse.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResponse.java index 36127771b..f90ce6436 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResponse.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResponse.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.watsonx.api; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; +package org.springframework.ai.watsonx.api; import java.util.Date; import java.util.List; import java.util.Map; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + /** * Java class for Watsonx.ai Chat Response object. * diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResults.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResults.java index 316f8c2e4..ecb67f8d9 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResults.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResults.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx.api; import com.fasterxml.jackson.annotation.JsonInclude; diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingRequest.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingRequest.java index 331dfa0a1..8e8da278d 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingRequest.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingRequest.java @@ -1,10 +1,27 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.watsonx.api; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.watsonx.WatsonxAiEmbeddingOptions; -import java.util.List; +import org.springframework.ai.watsonx.WatsonxAiEmbeddingOptions; /** * Java class for Watsonx.ai Embedding Request object. @@ -24,35 +41,35 @@ public class WatsonxAiEmbeddingRequest { @JsonProperty("project_id") String projectId; - public String getModel() { - return model; - } - - public List getInputs() { - return inputs; - } - private WatsonxAiEmbeddingRequest(String model, List inputs, String projectId) { this.model = model; this.inputs = inputs; this.projectId = projectId; } + public static Builder builder(List inputs) { + return new Builder(inputs); + } + + public String getModel() { + return this.model; + } + + public List getInputs() { + return this.inputs; + } + public WatsonxAiEmbeddingRequest withProjectId(String projectId) { this.projectId = projectId; return this; } - public static Builder builder(List inputs) { - return new Builder(inputs); - } - public static class Builder { - private String model = WatsonxAiEmbeddingOptions.DEFAULT_MODEL; - private final List inputs; + private String model = WatsonxAiEmbeddingOptions.DEFAULT_MODEL; + public Builder(List inputs) { this.inputs = inputs; } @@ -63,7 +80,7 @@ public class WatsonxAiEmbeddingRequest { } public WatsonxAiEmbeddingRequest build() { - return new WatsonxAiEmbeddingRequest(model, inputs, ""); + return new WatsonxAiEmbeddingRequest(this.model, this.inputs, ""); } } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResponse.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResponse.java index ec1ae0226..a2284afee 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResponse.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResponse.java @@ -1,11 +1,27 @@ -package org.springframework.ai.watsonx.api; +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; +package org.springframework.ai.watsonx.api; import java.util.Date; import java.util.List; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + /** * Java class for Watsonx.ai Embedding Response object. * @@ -16,4 +32,5 @@ import java.util.List; public record WatsonxAiEmbeddingResponse(@JsonProperty("model_id") String model, @JsonProperty("created_at") Date createdAt, @JsonProperty("results") List results, @JsonProperty("input_token_count") Integer inputTokenCount) { + } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResults.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResults.java index 975a1195e..a86dd12a2 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResults.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResults.java @@ -1,10 +1,24 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.watsonx.api; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -import java.util.List; - /** * Java class for Watsonx.ai Embedding Results object. * @@ -13,4 +27,5 @@ import java.util.List; */ @JsonInclude(JsonInclude.Include.NON_NULL) public record WatsonxAiEmbeddingResults(@JsonProperty("embedding") float[] embedding) { + } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/utils/MessageToPromptConverter.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/utils/MessageToPromptConverter.java index 75be3e173..449ec8f73 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/utils/MessageToPromptConverter.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/utils/MessageToPromptConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,20 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.watsonx.utils; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.MessageType; +package org.springframework.ai.watsonx.utils; import java.util.List; import java.util.stream.Collectors; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; + // @formatter:off public class MessageToPromptConverter { - private static final String HUMAN_PROMPT = "Human: "; - private static final String ASSISTANT_PROMPT = "Assistant: "; public static final String TOOL_EXECUTION_NOT_SUPPORTED_FOR_WAI_MODELS = "Tool execution results are not supported for watsonx.ai models"; + + private static final String HUMAN_PROMPT = "Human: "; + + private static final String ASSISTANT_PROMPT = "Assistant: "; + private String humanPrompt = HUMAN_PROMPT; private String assistantPrompt = ASSISTANT_PROMPT; @@ -60,7 +64,7 @@ public class MessageToPromptConverter { .map(this::messageToString) .collect(Collectors.joining("\n")); - return String.format("%s%n%n%s%n%s", systemMessages, userMessages, assistantPrompt).trim(); + return String.format("%s%n%n%s%n%s", systemMessages, userMessages, this.assistantPrompt).trim(); } protected String messageToString(Message message) { @@ -68,9 +72,9 @@ public class MessageToPromptConverter { case SYSTEM: return message.getContent(); case USER: - return humanPrompt + message.getContent(); + return this.humanPrompt + message.getContent(); case ASSISTANT: - return assistantPrompt + message.getContent(); + return this.assistantPrompt + message.getContent(); case TOOL: throw new IllegalArgumentException(TOOL_EXECUTION_NOT_SUPPORTED_FOR_WAI_MODELS); } diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java index 313e0bc25..4a41f72f7 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx; import java.util.Date; @@ -25,10 +26,10 @@ import org.junit.Test; import reactor.core.publisher.Flux; import reactor.test.StepVerifier; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.watsonx.api.WatsonxAiApi; @@ -57,7 +58,7 @@ public class WatsonxAiChatModelTest { Prompt prompt = new Prompt("Test message", options); Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> { - WatsonxAiChatRequest request = chatModel.request(prompt); + WatsonxAiChatRequest request = this.chatModel.request(prompt); }); } @@ -71,7 +72,7 @@ public class WatsonxAiChatModelTest { .build(); Prompt prompt = new Prompt(msg, modelOptions); - WatsonxAiChatRequest request = chatModel.request(prompt); + WatsonxAiChatRequest request = this.chatModel.request(prompt); Assert.assertEquals(request.getModelId(), "meta-llama/llama-2-70b-chat"); assertThat(request.getParameters().get("decoding_method")).isEqualTo("greedy"); @@ -105,7 +106,7 @@ public class WatsonxAiChatModelTest { Prompt prompt = new Prompt(msg, modelOptions); - WatsonxAiChatRequest request = chatModel.request(prompt); + WatsonxAiChatRequest request = this.chatModel.request(prompt); Assert.assertEquals(request.getModelId(), "meta-llama/llama-2-70b-chat"); assertThat(request.getParameters().get("decoding_method")).isEqualTo("sample"); @@ -139,7 +140,7 @@ public class WatsonxAiChatModelTest { Prompt prompt = new Prompt(msg, modelOptions); - WatsonxAiChatRequest request = chatModel.request(prompt); + WatsonxAiChatRequest request = this.chatModel.request(prompt); Assert.assertEquals(request.getModelId(), "meta-llama/llama-2-70b-chat"); assertThat(request.getInput()).isEqualTo(msg); diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModelTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModelTest.java index 42e6c0cc5..4e19920ec 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModelTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModelTest.java @@ -1,6 +1,26 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.watsonx; +import java.util.Date; +import java.util.List; + import org.junit.jupiter.api.Test; + import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.watsonx.api.WatsonxAiApi; @@ -9,9 +29,6 @@ import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingResponse; import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingResults; import org.springframework.http.ResponseEntity; -import java.util.Date; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -19,13 +36,13 @@ import static org.mockito.Mockito.when; public class WatsonxAiEmbeddingModelTest { - private WatsonxAiApi watsonxAiApiMock; - private final WatsonxAiEmbeddingModel embeddingModel; + private WatsonxAiApi watsonxAiApiMock; + public WatsonxAiEmbeddingModelTest() { this.watsonxAiApiMock = mock(WatsonxAiApi.class); - this.embeddingModel = new WatsonxAiEmbeddingModel(watsonxAiApiMock); + this.embeddingModel = new WatsonxAiEmbeddingModel(this.watsonxAiApiMock); } @Test @@ -34,7 +51,7 @@ public class WatsonxAiEmbeddingModelTest { List inputs = List.of("test"); WatsonxAiEmbeddingOptions options = WatsonxAiEmbeddingOptions.create().withModel(MODEL); - WatsonxAiEmbeddingRequest request = embeddingModel.watsonxAiEmbeddingRequest(inputs, options); + WatsonxAiEmbeddingRequest request = this.embeddingModel.watsonxAiEmbeddingRequest(inputs, options); assertThat(request.getModel()).isEqualTo(MODEL); assertThat(request.getInputs().size()).isEqualTo(inputs.size()); @@ -46,7 +63,7 @@ public class WatsonxAiEmbeddingModelTest { List inputs = List.of("test"); WatsonxAiEmbeddingOptions options = WatsonxAiEmbeddingOptions.create().withModel(MODEL); - WatsonxAiEmbeddingRequest request = embeddingModel.watsonxAiEmbeddingRequest(inputs, options); + WatsonxAiEmbeddingRequest request = this.embeddingModel.watsonxAiEmbeddingRequest(inputs, options); assertThat(request.getModel()).isEqualTo(WatsonxAiEmbeddingOptions.DEFAULT_MODEL); assertThat(request.getInputs().size()).isEqualTo(inputs.size()); @@ -55,7 +72,8 @@ public class WatsonxAiEmbeddingModelTest { @Test void createRequestWithNoOptions() { List inputs = List.of("test"); - WatsonxAiEmbeddingRequest request = embeddingModel.watsonxAiEmbeddingRequest(inputs, EmbeddingOptions.EMPTY); + WatsonxAiEmbeddingRequest request = this.embeddingModel.watsonxAiEmbeddingRequest(inputs, + EmbeddingOptions.EMPTY); assertThat(request.getModel()).isEqualTo(WatsonxAiEmbeddingOptions.DEFAULT_MODEL); assertThat(request.getInputs().size()).isEqualTo(inputs.size()); @@ -73,14 +91,14 @@ public class WatsonxAiEmbeddingModelTest { inputTokenCount); ResponseEntity mockResponseEntity = ResponseEntity.ok(mockResponse); - when(watsonxAiApiMock.embeddings(any(WatsonxAiEmbeddingRequest.class))).thenReturn(mockResponseEntity); + when(this.watsonxAiApiMock.embeddings(any(WatsonxAiEmbeddingRequest.class))).thenReturn(mockResponseEntity); - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); - assertThat(embeddingModel.dimensions()).isEqualTo(2); + assertThat(this.embeddingModel.dimensions()).isEqualTo(2); } } diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHintsTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHintsTest.java index fddaba9e5..7d82cd7a4 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHintsTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHintsTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx.aot; +import java.util.Set; + import org.junit.jupiter.api.Test; import org.springframework.ai.watsonx.WatsonxAiChatOptions; @@ -22,8 +25,6 @@ import org.springframework.ai.watsonx.api.WatsonxAiApi; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; -import java.util.Set; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiChatOptionTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiChatOptionTest.java index f77812852..ac71fe43e 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiChatOptionTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiChatOptionTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx.api; -import static org.assertj.core.api.Assertions.assertThat; +import java.util.List; +import java.util.Map; import org.junit.Test; import org.springframework.ai.watsonx.WatsonxAiChatOptions; -import java.util.List; -import java.util.Map; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Pablo Sanchidrian Herrera diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingOptionTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingOptionTest.java index 98d63092f..f5de75874 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingOptionTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingOptionTest.java @@ -1,10 +1,27 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.watsonx.api; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.Test; + import org.springframework.ai.watsonx.WatsonxAiEmbeddingOptions; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Pablo Sanchidrian Herrera */ diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/utils/MessageToPromptConverterTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/utils/MessageToPromptConverterTest.java index 5d22477c6..400541312 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/utils/MessageToPromptConverterTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/utils/MessageToPromptConverterTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.watsonx.utils; +import java.util.List; + import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.junit.jupiter.api.Disabled; -import org.springframework.ai.chat.messages.Message; + import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; -import java.util.List; - /** * @author Pablo Sanchidrian Herrera * @author John Jairo Moreno Rojas @@ -36,64 +37,64 @@ public class MessageToPromptConverterTest { @Before public void setUp() { - converter = MessageToPromptConverter.create().withHumanPrompt("").withAssistantPrompt(""); + this.converter = MessageToPromptConverter.create().withHumanPrompt("").withAssistantPrompt(""); } @Test public void testSingleUserMessage() { Message userMessage = new UserMessage("User message"); String expected = "User message"; - Assert.assertEquals(expected, converter.messageToString(userMessage)); + Assert.assertEquals(expected, this.converter.messageToString(userMessage)); } @Test public void testSingleAssistantMessage() { Message assistantMessage = new AssistantMessage("Assistant message"); String expected = "Assistant message"; - Assert.assertEquals(expected, converter.messageToString(assistantMessage)); + Assert.assertEquals(expected, this.converter.messageToString(assistantMessage)); } @Test public void testSystemMessageType() { Message systemMessage = new SystemMessage("System message"); String expected = "System message"; - Assert.assertEquals(expected, converter.messageToString(systemMessage)); + Assert.assertEquals(expected, this.converter.messageToString(systemMessage)); } @Test public void testCustomHumanPrompt() { - converter.withHumanPrompt("Custom Human: "); + this.converter.withHumanPrompt("Custom Human: "); Message userMessage = new UserMessage("User message"); String expected = "Custom Human: User message"; - Assert.assertEquals(expected, converter.messageToString(userMessage)); + Assert.assertEquals(expected, this.converter.messageToString(userMessage)); } @Test public void testCustomAssistantPrompt() { - converter.withAssistantPrompt("Custom Assistant: "); + this.converter.withAssistantPrompt("Custom Assistant: "); Message assistantMessage = new AssistantMessage("Assistant message"); String expected = "Custom Assistant: Assistant message"; - Assert.assertEquals(expected, converter.messageToString(assistantMessage)); + Assert.assertEquals(expected, this.converter.messageToString(assistantMessage)); } @Test public void testEmptyMessageList() { String expected = ""; - Assert.assertEquals(expected, converter.toPrompt(List.of())); + Assert.assertEquals(expected, this.converter.toPrompt(List.of())); } @Test public void testSystemMessageList() { String msg = "this is a LLM prompt"; SystemMessage message = new SystemMessage(msg); - Assert.assertEquals(msg, converter.toPrompt(List.of(message))); + Assert.assertEquals(msg, this.converter.toPrompt(List.of(message))); } @Test public void testUserMessageList() { List messages = List.of(new UserMessage("User message")); String expected = "User message"; - Assert.assertEquals(expected, converter.toPrompt(messages)); + Assert.assertEquals(expected, this.converter.toPrompt(messages)); } } diff --git a/models/spring-ai-zhipuai/pom.xml b/models/spring-ai-zhipuai/pom.xml index 4c2c6e179..59df1857b 100644 --- a/models/spring-ai-zhipuai/pom.xml +++ b/models/spring-ai-zhipuai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index 9a13668d6..b6b8e0659 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -64,16 +76,6 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.util.ArrayList; -import java.util.Base64; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal ZhiPuAI} @@ -91,16 +93,16 @@ public class ZhiPuAiChatModel extends AbstractToolCallSupport implements ChatMod private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); - /** - * The default options used for the chat completion requests. - */ - private final ZhiPuAiChatOptions defaultOptions; - /** * The retry template used to retry the ZhiPuAI API calls. */ public final RetryTemplate retryTemplate; + /** + * The default options used for the chat completion requests. + */ + private final ZhiPuAiChatOptions defaultOptions; + /** * Low-level access to the ZhiPuAI API. */ @@ -176,6 +178,21 @@ public class ZhiPuAiChatModel extends AbstractToolCallSupport implements ChatMod this.observationRegistry = observationRegistry; } + private static Generation buildGeneration(Choice choice, Map metadata) { + List toolCalls = choice.message().toolCalls() == null ? List.of() + : choice.message() + .toolCalls() + .stream() + .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function", + toolCall.function().name(), toolCall.function().arguments())) + .toList(); + + var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); + String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); + var generationMetadata = ChatGenerationMetadata.from(finishReason, null); + return new Generation(assistantMessage, generationMetadata); + } + @Override public ChatResponse call(Prompt prompt) { ChatCompletionRequest request = createRequest(prompt, false); @@ -318,21 +335,6 @@ public class ZhiPuAiChatModel extends AbstractToolCallSupport implements ChatMod .build(); } - private static Generation buildGeneration(Choice choice, Map metadata) { - List toolCalls = choice.message().toolCalls() == null ? List.of() - : choice.message() - .toolCalls() - .stream() - .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function", - toolCall.function().name(), toolCall.function().arguments())) - .toList(); - - var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); - String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); - var generationMetadata = ChatGenerationMetadata.from(finishReason, null); - return new Generation(assistantMessage, generationMetadata); - } - /** * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. * @param chunk the ChatCompletionChunk to convert diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java index e30ab9666..c0c66253a 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; @@ -27,12 +35,6 @@ import org.springframework.ai.zhipuai.api.ZhiPuAiApi.FunctionTool; import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.util.Assert; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - /** * ZhiPuAiChatOptions represents the options for the ZhiPuAiChat model. * @@ -137,6 +139,309 @@ public class ZhiPuAiChatOptions implements FunctionCallingOptions, ChatOptions { return new Builder(); } + public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { + return ZhiPuAiChatOptions.builder() + .withModel(fromOptions.getModel()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withStop(fromOptions.getStop()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTools(fromOptions.getTools()) + .withToolChoice(fromOptions.getToolChoice()) + .withUser(fromOptions.getUser()) + .withRequestId(fromOptions.getRequestId()) + .withDoSample(fromOptions.getDoSample()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withToolContext(fromOptions.getToolContext()) + .build(); + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + + public List getStop() { + return this.stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + @Override + public Double getTemperature() { + return this.temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + @Override + public Double getTopP() { + return this.topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + public List getTools() { + return this.tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + public String getToolChoice() { + return this.toolChoice; + } + + public void setToolChoice(String toolChoice) { + this.toolChoice = toolChoice; + } + + public String getUser() { + return this.user; + } + + public void setUser(String user) { + this.user = user; + } + + public String getRequestId() { + return this.requestId; + } + + public void setRequestId(String requestId) { + this.requestId = requestId; + } + + public Boolean getDoSample() { + return this.doSample; + } + + public void setDoSample(Boolean doSample) { + this.doSample = doSample; + } + + @Override + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + @Override + public void setFunctionCallbacks(List functionCallbacks) { + this.functionCallbacks = functionCallbacks; + } + + @Override + public Set getFunctions() { + return this.functions; + } + + public void setFunctions(Set functionNames) { + this.functions = functionNames; + } + + @Override + @JsonIgnore + public Double getFrequencyPenalty() { + return null; + } + + @Override + @JsonIgnore + public Double getPresencePenalty() { + return null; + } + + @Override + @JsonIgnore + public Integer getTopK() { + return null; + } + + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + + @Override + public Map getToolContext() { + return this.toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); + result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); + result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); + result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); + result = prime * result + ((this.tools == null) ? 0 : this.tools.hashCode()); + result = prime * result + ((this.toolChoice == null) ? 0 : this.toolChoice.hashCode()); + result = prime * result + ((this.user == null) ? 0 : this.user.hashCode()); + result = prime * result + ((this.proxyToolCalls == null) ? 0 : this.proxyToolCalls.hashCode()); + result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + ZhiPuAiChatOptions other = (ZhiPuAiChatOptions) obj; + if (this.model == null) { + if (other.model != null) { + return false; + } + } + else if (!this.model.equals(other.model)) { + return false; + } + if (this.maxTokens == null) { + if (other.maxTokens != null) { + return false; + } + } + else if (!this.maxTokens.equals(other.maxTokens)) { + return false; + } + if (this.stop == null) { + if (other.stop != null) { + return false; + } + } + else if (!this.stop.equals(other.stop)) { + return false; + } + if (this.temperature == null) { + if (other.temperature != null) { + return false; + } + } + else if (!this.temperature.equals(other.temperature)) { + return false; + } + if (this.topP == null) { + if (other.topP != null) { + return false; + } + } + else if (!this.topP.equals(other.topP)) { + return false; + } + if (this.tools == null) { + if (other.tools != null) { + return false; + } + } + else if (!this.tools.equals(other.tools)) { + return false; + } + if (this.toolChoice == null) { + if (other.toolChoice != null) { + return false; + } + } + else if (!this.toolChoice.equals(other.toolChoice)) { + return false; + } + if (this.user == null) { + if (other.user != null) { + return false; + } + } + else if (!this.user.equals(other.user)) { + return false; + } + if (this.requestId == null) { + if (other.requestId != null) { + return false; + } + } + else if (!this.requestId.equals(other.requestId)) { + return false; + } + if (this.doSample == null) { + if (other.doSample != null) { + return false; + } + } + else if (!this.doSample.equals(other.doSample)) { + return false; + } + if (this.proxyToolCalls == null) { + if (other.proxyToolCalls != null) { + return false; + } + } + else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) { + return false; + } + if (this.toolContext == null) { + if (other.toolContext != null) { + return false; + } + } + else if (!this.toolContext.equals(other.toolContext)) { + return false; + } + return true; + } + + @Override + public ZhiPuAiChatOptions copy() { + return fromOptions(this); + } + public static class Builder { protected ZhiPuAiChatOptions options; @@ -237,280 +542,4 @@ public class ZhiPuAiChatOptions implements FunctionCallingOptions, ChatOptions { } - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - @Override - public Integer getMaxTokens() { - return this.maxTokens; - } - - public void setMaxTokens(Integer maxTokens) { - this.maxTokens = maxTokens; - } - - @Override - @JsonIgnore - public List getStopSequences() { - return getStop(); - } - - @JsonIgnore - public void setStopSequences(List stopSequences) { - setStop(stopSequences); - } - - public List getStop() { - return this.stop; - } - - public void setStop(List stop) { - this.stop = stop; - } - - @Override - public Double getTemperature() { - return this.temperature; - } - - public void setTemperature(Double temperature) { - this.temperature = temperature; - } - - @Override - public Double getTopP() { - return this.topP; - } - - public void setTopP(Double topP) { - this.topP = topP; - } - - public List getTools() { - return this.tools; - } - - public void setTools(List tools) { - this.tools = tools; - } - - public String getToolChoice() { - return this.toolChoice; - } - - public void setToolChoice(String toolChoice) { - this.toolChoice = toolChoice; - } - - public String getUser() { - return this.user; - } - - public void setUser(String user) { - this.user = user; - } - - public String getRequestId() { - return requestId; - } - - public void setRequestId(String requestId) { - this.requestId = requestId; - } - - public Boolean getDoSample() { - return doSample; - } - - public void setDoSample(Boolean doSample) { - this.doSample = doSample; - } - - @Override - public List getFunctionCallbacks() { - return this.functionCallbacks; - } - - @Override - public void setFunctionCallbacks(List functionCallbacks) { - this.functionCallbacks = functionCallbacks; - } - - @Override - public Set getFunctions() { - return functions; - } - - public void setFunctions(Set functionNames) { - this.functions = functionNames; - } - - @Override - @JsonIgnore - public Double getFrequencyPenalty() { - return null; - } - - @Override - @JsonIgnore - public Double getPresencePenalty() { - return null; - } - - @Override - @JsonIgnore - public Integer getTopK() { - return null; - } - - @Override - public Boolean getProxyToolCalls() { - return this.proxyToolCalls; - } - - public void setProxyToolCalls(Boolean proxyToolCalls) { - this.proxyToolCalls = proxyToolCalls; - } - - @Override - public Map getToolContext() { - return this.toolContext; - } - - @Override - public void setToolContext(Map toolContext) { - this.toolContext = toolContext; - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); - result = prime * result + ((stop == null) ? 0 : stop.hashCode()); - result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); - result = prime * result + ((topP == null) ? 0 : topP.hashCode()); - result = prime * result + ((tools == null) ? 0 : tools.hashCode()); - result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); - result = prime * result + ((user == null) ? 0 : user.hashCode()); - result = prime * result + ((proxyToolCalls == null) ? 0 : proxyToolCalls.hashCode()); - result = prime * result + ((toolContext == null) ? 0 : toolContext.hashCode()); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - ZhiPuAiChatOptions other = (ZhiPuAiChatOptions) obj; - if (this.model == null) { - if (other.model != null) - return false; - } - else if (!model.equals(other.model)) - return false; - if (this.maxTokens == null) { - if (other.maxTokens != null) - return false; - } - else if (!this.maxTokens.equals(other.maxTokens)) - return false; - if (this.stop == null) { - if (other.stop != null) - return false; - } - else if (!stop.equals(other.stop)) - return false; - if (this.temperature == null) { - if (other.temperature != null) - return false; - } - else if (!this.temperature.equals(other.temperature)) - return false; - if (this.topP == null) { - if (other.topP != null) - return false; - } - else if (!topP.equals(other.topP)) - return false; - if (this.tools == null) { - if (other.tools != null) - return false; - } - else if (!tools.equals(other.tools)) - return false; - if (this.toolChoice == null) { - if (other.toolChoice != null) - return false; - } - else if (!toolChoice.equals(other.toolChoice)) - return false; - if (this.user == null) { - if (other.user != null) - return false; - } - else if (!this.user.equals(other.user)) - return false; - if (this.requestId == null) { - if (other.requestId != null) - return false; - } - else if (!this.requestId.equals(other.requestId)) - return false; - if (this.doSample == null) { - if (other.doSample != null) - return false; - } - else if (!this.doSample.equals(other.doSample)) - return false; - if (this.proxyToolCalls == null) { - if (other.proxyToolCalls != null) - return false; - } - else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) - return false; - if (this.toolContext == null) { - if (other.toolContext != null) - return false; - } - else if (!this.toolContext.equals(other.toolContext)) - return false; - return true; - } - - @Override - public ZhiPuAiChatOptions copy() { - return fromOptions(this); - } - - public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { - return ZhiPuAiChatOptions.builder() - .withModel(fromOptions.getModel()) - .withMaxTokens(fromOptions.getMaxTokens()) - .withStop(fromOptions.getStop()) - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withTools(fromOptions.getTools()) - .withToolChoice(fromOptions.getToolChoice()) - .withUser(fromOptions.getUser()) - .withRequestId(fromOptions.getRequestId()) - .withDoSample(fromOptions.getDoSample()) - .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) - .withFunctions(fromOptions.getFunctions()) - .withProxyToolCalls(fromOptions.getProxyToolCalls()) - .withToolContext(fromOptions.getToolContext()) - .build(); - } - } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java index d33035420..214679c32 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -39,10 +45,6 @@ import org.springframework.lang.Nullable; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; - /** * ZhiPuAI Embedding Model implementation. * diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java index cbd75ad4e..02119d53c 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.embedding.EmbeddingOptions; /** @@ -42,6 +44,21 @@ public class ZhiPuAiEmbeddingOptions implements EmbeddingOptions { return new Builder(); } + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + @JsonIgnore + public Integer getDimensions() { + return null; + } + public static class Builder { protected ZhiPuAiEmbeddingOptions options; @@ -61,19 +78,4 @@ public class ZhiPuAiEmbeddingOptions implements EmbeddingOptions { } - @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - @Override - @JsonIgnore - public Integer getDimensions() { - return null; - } - } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java index f7464cb79..cb267fd2f 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; +import java.util.List; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.image.Image; import org.springframework.ai.image.ImageGeneration; import org.springframework.ai.image.ImageModel; @@ -30,8 +34,6 @@ import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.util.List; - /** * ZhiPuAiImageModel is a class that implements the ImageModel interface. It provides a * client for calling the ZhiPuAI image generation API. @@ -43,12 +45,12 @@ public class ZhiPuAiImageModel implements ImageModel { private final static Logger logger = LoggerFactory.getLogger(ZhiPuAiImageModel.class); + public final RetryTemplate retryTemplate; + private final ZhiPuAiImageOptions defaultOptions; private final ZhiPuAiImageApi zhiPuAiImageApi; - public final RetryTemplate retryTemplate; - public ZhiPuAiImageModel(ZhiPuAiImageApi zhiPuAiImageApi) { this(zhiPuAiImageApi, ZhiPuAiImageOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE); } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java index a6d1de316..baa1e8475 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; +import java.util.Objects; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.image.ImageOptions; import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; -import java.util.Objects; - /** * ZhiPuAiImageOptions represents the options for image generation using ZhiPuAI image * model. @@ -64,30 +66,6 @@ public class ZhiPuAiImageOptions implements ImageOptions { return new Builder(); } - public static class Builder { - - private final ZhiPuAiImageOptions options; - - private Builder() { - this.options = new ZhiPuAiImageOptions(); - } - - public Builder withModel(String model) { - options.setModel(model); - return this; - } - - public Builder withUser(String user) { - options.setUser(user); - return this; - } - - public ZhiPuAiImageOptions build() { - return options; - } - - } - @Override @JsonIgnore public Integer getN() { @@ -137,21 +115,47 @@ public class ZhiPuAiImageOptions implements ImageOptions { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ZhiPuAiImageOptions that)) + } + if (!(o instanceof ZhiPuAiImageOptions that)) { return false; - return Objects.equals(model, that.model) && Objects.equals(user, that.user); + } + return Objects.equals(this.model, that.model) && Objects.equals(this.user, that.user); } @Override public int hashCode() { - return Objects.hash(model, user); + return Objects.hash(this.model, this.user); } @Override public String toString() { - return "ZhiPuAiImageOptions{model='" + model + '\'' + ", user='" + user + '\'' + '}'; + return "ZhiPuAiImageOptions{model='" + this.model + '\'' + ", user='" + this.user + '\'' + '}'; + } + + public static class Builder { + + private final ZhiPuAiImageOptions options; + + private Builder() { + this.options = new ZhiPuAiImageOptions(); + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withUser(String user) { + this.options.setUser(user); + return this; + } + + public ZhiPuAiImageOptions build() { + return this.options; + } + } } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java index 51185977a..d5cc2f21e 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.aot; import org.springframework.ai.zhipuai.api.ZhiPuAiApi; diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java index 0758d2e8f..2be99709d 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.api; import java.util.Arrays; @@ -23,6 +24,12 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.function.Predicate; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; @@ -37,13 +44,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - // @formatter:off /** * Single class implementation of the ZhiPuAI Chat Completion API and @@ -63,6 +63,8 @@ public class ZhiPuAiApi { private final WebClient webClient; + private final ZhiPuAiStreamFunctionCallingHelper chunkMerger = new ZhiPuAiStreamFunctionCallingHelper(); + /** * Create a new chat completion api with default base URL. * @@ -120,6 +122,111 @@ public class ZhiPuAiApi { .build(); } + public static String getTextContent(List content) { + return content.stream() + .filter(c -> "text".equals(c.type())) + .map(ChatCompletionMessage.MediaContent::text) + .reduce("", (a, b) -> a + b); + } + + /** + * Creates a model response for the given chat conversation. + * + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post() + .uri("/v4/chat/completions") + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * + * @param chatRequest The chat completion request. Must have the stream property set to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + AtomicBoolean isInsideTool = new AtomicBoolean(false); + + return this.webClient.post() + .uri("/v4/chat/completions") + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + .map(chunk -> { + if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { + isInsideTool.set(true); + } + return chunk; + }) + .windowUntil(chunk -> { + if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + .concatMapIterable(window -> { + Mono monoChunk = window.reduce( + new ChatCompletionChunk(null, null, null, null, null, null), + this.chunkMerger::merge); + return List.of(monoChunk); + }) + .flatMap(mono -> mono); + } + + /** + * Creates an embedding vector representing the input text or token array. + * + * @param embeddingRequest The embedding request. + * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. + * @param Type of the entity in the data list. Can be a {@link String} or {@link List} of tokens (e.g. + * Integers). For embedding multiple inputs in a single request, You can pass a {@link List} of {@link String} or + * {@link List} of {@link List} of tokens. For example: + * + *

{@code List.of("text1", "text2", "text3") or List.of(List.of(1, 2, 3), List.of(3, 4, 5))} 
+ */ + public ResponseEntity> embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + + // Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single + // request, pass an array of strings or array of token arrays. + Assert.notNull(embeddingRequest.input(), "The input can not be null."); + Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, + "The input must be either a String, or a List of Strings or List of List of integers."); + + if (embeddingRequest.input() instanceof List list) { + Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); + Assert.isTrue(list.size() <= 512, "The list must be 512 dimensions or less"); + Assert.isTrue(list.get(0) instanceof String || list.get(0) instanceof Integer + || list.get(0) instanceof List, + "The input must be either a String, or a List of Strings or list of list of integers."); + } + + return this.restClient.post() + .uri("/v4/embeddings") + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + }); + } + /** * ZhiPuAI Chat Completion Models: * ZhiPuAI Model. @@ -139,7 +246,7 @@ public class ZhiPuAiApi { } public String getValue() { - return value; + return this.value; } @Override @@ -148,6 +255,58 @@ public class ZhiPuAiApi { } } + /** + * The reason the model stopped generating tokens. + */ + public enum ChatCompletionFinishReason { + /** + * The model hit a natural stop point or a provided stop sequence. + */ + @JsonProperty("stop") STOP, + /** + * The maximum number of tokens specified in the request was reached. + */ + @JsonProperty("length") LENGTH, + /** + * The content was omitted due to a flag from our content filters. + */ + @JsonProperty("content_filter") CONTENT_FILTER, + /** + * The model called a tool. + */ + @JsonProperty("tool_calls") TOOL_CALLS, + /** + * (deprecated) The model called a function. + */ + @JsonProperty("function_call") FUNCTION_CALL, + /** + * Only for compatibility with Mistral AI API. + */ + @JsonProperty("tool_call") TOOL_CALL + } + + /** + * ZhiPuAI Embeddings Models: + * Embeddings. + */ + public enum EmbeddingModel { + + /** + * DIMENSION: 1024 + */ + Embedding_2("Embedding-2"); + + public final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + } + /** * Represents a tool the model may call. Currently, only functions are supported as a tool. * @@ -355,6 +514,15 @@ public class ZhiPuAiApi { @JsonProperty("tool_call_id") String toolCallId, @JsonProperty("tool_calls") List toolCalls) { + /** + * Create a chat completion message with the given content and role. All other fields are null. + * @param content The contents of the message. + * @param role The role of the author of this message. + */ + public ChatCompletionMessage(Object content, Role role) { + this(content, role, null, null, null); + } + /** * Get message content as String. */ @@ -368,15 +536,6 @@ public class ZhiPuAiApi { throw new IllegalStateException("The content is not a string!"); } - /** - * Create a chat completion message with the given content and role. All other fields are null. - * @param content The contents of the message. - * @param role The role of the author of this message. - */ - public ChatCompletionMessage(Object content, Role role) { - this(content, role, null, null, null); - } - /** * The role of the author of this message. */ @@ -415,22 +574,6 @@ public class ZhiPuAiApi { @JsonProperty("text") String text, @JsonProperty("image_url") ImageUrl imageUrl) { - /** - * @param url Either a URL of the image or the base64 encoded image data. - * The base64 encoded image data must have a special prefix in the following format: - * "data:{mimetype};base64,{base64-encoded-image-data}". - * @param detail Specifies the detail level of the image. - */ - @JsonInclude(Include.NON_NULL) - public record ImageUrl( - @JsonProperty("url") String url, - @JsonProperty("detail") String detail) { - - public ImageUrl(String url) { - this(url, null); - } - } - /** * Shortcut constructor for a text content. * @param text The text content of the message. @@ -446,6 +589,22 @@ public class ZhiPuAiApi { public MediaContent(ImageUrl imageUrl) { this("image_url", null, imageUrl); } + + /** + * @param url Either a URL of the image or the base64 encoded image data. + * The base64 encoded image data must have a special prefix in the following format: + * "data:{mimetype};base64,{base64-encoded-image-data}". + * @param detail Specifies the detail level of the image. + */ + @JsonInclude(Include.NON_NULL) + public record ImageUrl( + @JsonProperty("url") String url, + @JsonProperty("detail") String detail) { + + public ImageUrl(String url) { + this(url, null); + } + } } /** * The relevant tool call. @@ -475,43 +634,6 @@ public class ZhiPuAiApi { } } - public static String getTextContent(List content) { - return content.stream() - .filter(c -> "text".equals(c.type())) - .map(ChatCompletionMessage.MediaContent::text) - .reduce("", (a, b) -> a + b); - } - - /** - * The reason the model stopped generating tokens. - */ - public enum ChatCompletionFinishReason { - /** - * The model hit a natural stop point or a provided stop sequence. - */ - @JsonProperty("stop") STOP, - /** - * The maximum number of tokens specified in the request was reached. - */ - @JsonProperty("length") LENGTH, - /** - * The content was omitted due to a flag from our content filters. - */ - @JsonProperty("content_filter") CONTENT_FILTER, - /** - * The model called a tool. - */ - @JsonProperty("tool_calls") TOOL_CALLS, - /** - * (deprecated) The model called a function. - */ - @JsonProperty("function_call") FUNCTION_CALL, - /** - * Only for compatibility with Mistral AI API. - */ - @JsonProperty("tool_call") TOOL_CALL - } - /** * Represents a chat completion response returned by model, based on the provided input. * @@ -655,91 +777,6 @@ public class ZhiPuAiApi { } } - /** - * Creates a model response for the given chat conversation. - * - * @param chatRequest The chat completion request. - * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. - */ - public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); - - return this.restClient.post() - .uri("/v4/chat/completions") - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletion.class); - } - - private final ZhiPuAiStreamFunctionCallingHelper chunkMerger = new ZhiPuAiStreamFunctionCallingHelper(); - - /** - * Creates a streaming chat response for the given chat conversation. - * - * @param chatRequest The chat completion request. Must have the stream property set to true. - * @return Returns a {@link Flux} stream from chat completion chunks. - */ - public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - - Assert.notNull(chatRequest, "The request body can not be null."); - Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); - - AtomicBoolean isInsideTool = new AtomicBoolean(false); - - return this.webClient.post() - .uri("/v4/chat/completions") - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - .takeUntil(SSE_DONE_PREDICATE) - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) - .map(chunk -> { - if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { - isInsideTool.set(true); - } - return chunk; - }) - .windowUntil(chunk -> { - if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - .concatMapIterable(window -> { - Mono monoChunk = window.reduce( - new ChatCompletionChunk(null, null, null, null, null, null), - this.chunkMerger::merge); - return List.of(monoChunk); - }) - .flatMap(mono -> mono); - } - - /** - * ZhiPuAI Embeddings Models: - * Embeddings. - */ - public enum EmbeddingModel { - - /** - * DIMENSION: 1024 - */ - Embedding_2("Embedding-2"); - - public final String value; - - EmbeddingModel(String value) { - this.value = value; - } - - public String getValue() { - return value; - } - } - /** * Represents an embedding vector returned by embedding endpoint. * @@ -765,20 +802,20 @@ public class ZhiPuAiApi { @Override public boolean equals(Object o) { if (this == o) return true; if (!(o instanceof Embedding embedding1)) return false; - return Objects.equals(index, embedding1.index) && Arrays.equals(embedding, embedding1.embedding) && Objects.equals(object, embedding1.object); + return Objects.equals(this.index, embedding1.index) && Arrays.equals(this.embedding, embedding1.embedding) && Objects.equals(this.object, embedding1.object); } @Override public int hashCode() { - int result = Objects.hash(index, object); - result = 31 * result + Arrays.hashCode(embedding); + int result = Objects.hash(this.index, this.object); + result = 31 * result + Arrays.hashCode(this.embedding); return result; } @Override public String toString() { return "Embedding{" + - "index=" + index + - ", embedding=" + Arrays.toString(embedding) + - ", object='" + object + '\'' + + "index=" + this.index + + ", embedding=" + Arrays.toString(this.embedding) + + ", object='" + this.object + '\'' + '}'; } } @@ -821,42 +858,5 @@ public class ZhiPuAiApi { @JsonProperty("usage") Usage usage) { } - /** - * Creates an embedding vector representing the input text or token array. - * - * @param embeddingRequest The embedding request. - * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. - * @param Type of the entity in the data list. Can be a {@link String} or {@link List} of tokens (e.g. - * Integers). For embedding multiple inputs in a single request, You can pass a {@link List} of {@link String} or - * {@link List} of {@link List} of tokens. For example: - * - *
{@code List.of("text1", "text2", "text3") or List.of(List.of(1, 2, 3), List.of(3, 4, 5))} 
- */ - public ResponseEntity> embeddings(EmbeddingRequest embeddingRequest) { - - Assert.notNull(embeddingRequest, "The request body can not be null."); - - // Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single - // request, pass an array of strings or array of token arrays. - Assert.notNull(embeddingRequest.input(), "The input can not be null."); - Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, - "The input must be either a String, or a List of Strings or List of List of integers."); - - if (embeddingRequest.input() instanceof List list) { - Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); - Assert.isTrue(list.size() <= 512, "The list must be 512 dimensions or less"); - Assert.isTrue(list.get(0) instanceof String || list.get(0) instanceof Integer - || list.get(0) instanceof List, - "The input must be either a String, or a List of Strings or list of list of integers."); - } - - return this.restClient.post() - .uri("/v4/embeddings") - .body(embeddingRequest) - .retrieve() - .toEntity(new ParameterizedTypeReference<>() { - }); - } - } // @formatter:on diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java index 304ec3146..23bfd8404 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,21 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.api; import java.util.List; -import java.util.function.Consumer; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.retry.RetryUtils; -import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; - /** * ZhiPuAI Image API. * @@ -76,6 +75,17 @@ public class ZhiPuAiImageApi { }).defaultStatusHandler(responseErrorHandler).build(); } + public ResponseEntity createImage(ZhiPuAiImageRequest zhiPuAiImageRequest) { + Assert.notNull(zhiPuAiImageRequest, "Image request cannot be null."); + Assert.hasLength(zhiPuAiImageRequest.prompt(), "Prompt cannot be empty."); + + return this.restClient.post() + .uri("/v4/images/generations") + .body(zhiPuAiImageRequest) + .retrieve() + .toEntity(ZhiPuAiImageResponse.class); + } + /** * ZhiPuAI Image API model. * CogView @@ -113,22 +123,11 @@ public class ZhiPuAiImageApi { @JsonProperty("created") Long created, @JsonProperty("data") List data) { } - - @JsonInclude(JsonInclude.Include.NON_NULL) - public record Data( - @JsonProperty("url") String url) { - } // @formatter:onn - public ResponseEntity createImage(ZhiPuAiImageRequest zhiPuAiImageRequest) { - Assert.notNull(zhiPuAiImageRequest, "Image request cannot be null."); - Assert.hasLength(zhiPuAiImageRequest.prompt(), "Prompt cannot be empty."); + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Data(@JsonProperty("url") String url) { - return this.restClient.post() - .uri("/v4/images/generations") - .body(zhiPuAiImageRequest) - .retrieve() - .toEntity(ZhiPuAiImageResponse.class); } } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java index b74303729..e4629e94b 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.api; +import java.util.ArrayList; +import java.util.List; + import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletion; import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletion.Choice; import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionChunk; @@ -27,9 +31,6 @@ import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage.ToolC import org.springframework.ai.zhipuai.api.ZhiPuAiApi.LogProbs; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; -import java.util.List; - /** * Helper class to support Streaming function calling. It can merge the streamed * ChatCompletionChunk in case of function calling message. diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuApiConstants.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuApiConstants.java index 36d0c4292..52f242771 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuApiConstants.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuApiConstants.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.zhipuai.api; import org.springframework.ai.observation.conventions.AiProvider; diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/metadata/ZhiPuAiUsage.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/metadata/ZhiPuAiUsage.java index dc47c1c7f..88d197e9f 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/metadata/ZhiPuAiUsage.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/metadata/ZhiPuAiUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.metadata; import org.springframework.ai.chat.metadata.Usage; @@ -27,10 +28,6 @@ import org.springframework.util.Assert; */ public class ZhiPuAiUsage implements Usage { - public static ZhiPuAiUsage from(ZhiPuAiApi.Usage usage) { - return new ZhiPuAiUsage(usage); - } - private final ZhiPuAiApi.Usage usage; protected ZhiPuAiUsage(ZhiPuAiApi.Usage usage) { @@ -38,6 +35,10 @@ public class ZhiPuAiUsage implements Usage { this.usage = usage; } + public static ZhiPuAiUsage from(ZhiPuAiApi.Usage usage) { + return new ZhiPuAiUsage(usage); + } + protected ZhiPuAiApi.Usage getUsage() { return this.usage; } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java index eb12b04c5..90dac9f57 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; +import java.util.List; + import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.zhipuai.api.MockWeatherService; import org.springframework.ai.zhipuai.api.ZhiPuAiApi; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java index d35ac839a..00a760cb1 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai; import org.springframework.ai.embedding.EmbeddingModel; diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java index 0d68d135f..c1487282b 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,31 +13,37 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.api; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * @author Geng Rong */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, request.unit); } /** @@ -65,28 +71,25 @@ public class MockWeatherService implements Function response = zhiPuAiApi + ResponseEntity response = this.zhiPuAiApi .chatCompletionEntity(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-3-turbo", 0.7, false)); assertThat(response).isNotNull(); @@ -53,7 +55,7 @@ public class ZhiPuAiApiIT { @Test void chatCompletionEntityWithMoreParams() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - ResponseEntity response = zhiPuAiApi + ResponseEntity response = this.zhiPuAiApi .chatCompletionEntity(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-3-turbo", 1024, null, false, 0.95, 0.7, null, null, null, "test_request_id", false)); @@ -64,7 +66,7 @@ public class ZhiPuAiApiIT { @Test void chatCompletionStream() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); - Flux response = zhiPuAiApi + Flux response = this.zhiPuAiApi .chatCompletionStream(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-3-turbo", 0.7, true)); assertThat(response).isNotNull(); @@ -73,7 +75,7 @@ public class ZhiPuAiApiIT { @Test void embeddings() { - ResponseEntity> response = zhiPuAiApi + ResponseEntity> response = this.zhiPuAiApi .embeddings(new ZhiPuAiApi.EmbeddingRequest<>("Hello world")); assertThat(response).isNotNull(); diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java index cf249ae1c..2c6de05af 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,12 +16,17 @@ package org.springframework.ai.zhipuai.api; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletion; import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage; @@ -32,10 +37,6 @@ import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionRequest.ToolC import org.springframework.ai.zhipuai.api.ZhiPuAiApi.FunctionTool.Type; import org.springframework.http.ResponseEntity; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatModel.GLM_4; @@ -51,6 +52,15 @@ public class ZhiPuAiApiToolFunctionCallIT { ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY")); + private static T fromJson(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + @SuppressWarnings("null") @Test public void toolFunctionCall() { @@ -92,7 +102,7 @@ public class ZhiPuAiApiToolFunctionCallIT { ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(messages, GLM_4.value, List.of(functionTool), ToolChoiceBuilder.AUTO); - ResponseEntity chatCompletion = zhiPuAiApi.chatCompletionEntity(chatCompletionRequest); + ResponseEntity chatCompletion = this.zhiPuAiApi.chatCompletionEntity(chatCompletionRequest); assertThat(chatCompletion.getBody()).isNotNull(); assertThat(chatCompletion.getBody().choices()).isNotEmpty(); @@ -111,7 +121,7 @@ public class ZhiPuAiApiToolFunctionCallIT { MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(), MockWeatherService.Request.class); - MockWeatherService.Response weatherResponse = weatherService.apply(weatherRequest); + MockWeatherService.Response weatherResponse = this.weatherService.apply(weatherRequest); // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL, @@ -122,9 +132,9 @@ public class ZhiPuAiApiToolFunctionCallIT { var functionResponseRequest = new ChatCompletionRequest(messages, GLM_4.value, List.of(functionTool), ToolChoiceBuilder.AUTO); - ResponseEntity chatCompletion2 = zhiPuAiApi.chatCompletionEntity(functionResponseRequest); + ResponseEntity chatCompletion2 = this.zhiPuAiApi.chatCompletionEntity(functionResponseRequest); - logger.info("Final response: " + chatCompletion2.getBody()); + this.logger.info("Final response: " + chatCompletion2.getBody()); assertThat(Objects.requireNonNull(chatCompletion2.getBody()).choices()).isNotEmpty(); @@ -133,13 +143,4 @@ public class ZhiPuAiApiToolFunctionCallIT { .containsAnyOf("30.0°C", "30°C", "30.0°F", "30°F"); } - private static T fromJson(String json, Class targetClass) { - try { - return new ObjectMapper().readValue(json, targetClass); - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - -} \ No newline at end of file +} diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java index aa76c1be8..af2d14750 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.api; +import java.util.List; +import java.util.Optional; + import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.image.ImageMessage; @@ -49,10 +55,6 @@ import org.springframework.retry.RetryCallback; import org.springframework.retry.RetryContext; import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Optional; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -66,25 +68,6 @@ import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) public class ZhiPuAiRetryTests { - private class TestRetryListener implements RetryListener { - - int onErrorRetryCount = 0; - - int onSuccessRetryCount = 0; - - @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - onSuccessRetryCount = context.getRetryCount(); - } - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - onErrorRetryCount = context.getRetryCount(); - } - - } - private TestRetryListener retryListener; private RetryTemplate retryTemplate; @@ -101,14 +84,16 @@ public class ZhiPuAiRetryTests { @BeforeEach public void beforeEach() { - retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - retryListener = new TestRetryListener(); - retryTemplate.registerListener(retryListener); + this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.registerListener(this.retryListener); - chatModel = new ZhiPuAiChatModel(zhiPuAiApi, ZhiPuAiChatOptions.builder().build(), null, retryTemplate); - embeddingModel = new ZhiPuAiEmbeddingModel(zhiPuAiApi, MetadataMode.EMBED, - ZhiPuAiEmbeddingOptions.builder().build(), retryTemplate); - imageModel = new ZhiPuAiImageModel(zhiPuAiImageApi, ZhiPuAiImageOptions.builder().build(), retryTemplate); + this.chatModel = new ZhiPuAiChatModel(this.zhiPuAiApi, ZhiPuAiChatOptions.builder().build(), null, + this.retryTemplate); + this.embeddingModel = new ZhiPuAiEmbeddingModel(this.zhiPuAiApi, MetadataMode.EMBED, + ZhiPuAiEmbeddingOptions.builder().build(), this.retryTemplate); + this.imageModel = new ZhiPuAiImageModel(this.zhiPuAiImageApi, ZhiPuAiImageOptions.builder().build(), + this.retryTemplate); } @Test @@ -119,24 +104,24 @@ public class ZhiPuAiRetryTests { ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666l, "model", null, null, new ZhiPuAiApi.Usage(10, 10, 10)); - when(zhiPuAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.zhiPuAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); - var result = chatModel.call(new Prompt("text")); + var result = this.chatModel.call(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void zhiPuAiChatNonTransientError() { - when(zhiPuAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + when(this.zhiPuAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("text"))); + assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); } @Test @@ -147,24 +132,24 @@ public class ZhiPuAiRetryTests { ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666l, "model", null, null); - when(zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(Flux.just(expectedChatCompletion)); - var result = chatModel.stream(new Prompt("text")); + var result = this.chatModel.stream(new Prompt("text")); assertThat(result).isNotNull(); assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void zhiPuAiChatStreamNonTransientError() { - when(zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + when(this.zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")).collectList().block()); + assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text")).collectList().block()); } @Test @@ -173,24 +158,25 @@ public class ZhiPuAiRetryTests { EmbeddingList expectedEmbeddings = new EmbeddingList<>("list", List.of(new Embedding(0, new float[] { 9.9f, 8.8f })), "model", new ZhiPuAiApi.Usage(10, 10, 10)); - when(zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))) + when(this.zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); - var result = embeddingModel + var result = this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(0); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(0); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void zhiPuAiEmbeddingNonTransientError() { - when(zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))).thenThrow(new RuntimeException("Non Transient Error")); - assertThrows(RuntimeException.class, () -> embeddingModel + when(this.zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); } @@ -199,25 +185,44 @@ public class ZhiPuAiRetryTests { var expectedResponse = new ZhiPuAiImageResponse(678l, List.of(new Data("url678"))); - when(zhiPuAiImageApi.createImage(isA(ZhiPuAiImageRequest.class))) + when(this.zhiPuAiImageApi.createImage(isA(ZhiPuAiImageRequest.class))) .thenThrow(new TransientAiException("Transient Error 1")) .thenThrow(new TransientAiException("Transient Error 2")) .thenReturn(ResponseEntity.of(Optional.of(expectedResponse))); - var result = imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); + var result = this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getUrl()).isEqualTo("url678"); - assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @Test public void zhiPuAiImageNonTransientError() { - when(zhiPuAiImageApi.createImage(isA(ZhiPuAiImageRequest.class))) + when(this.zhiPuAiImageApi.createImage(isA(ZhiPuAiImageRequest.class))) .thenThrow(new RuntimeException("Transient Error 1")); assertThrows(RuntimeException.class, - () -> imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); + () -> this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); + } + + private class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + } } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ActorsFilms.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ActorsFilms.java index 008ffecdb..26d1ec5ad 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ActorsFilms.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ActorsFilms.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.chat; import java.util.List; @@ -30,7 +31,7 @@ public class ActorsFilms { } public String getActor() { - return actor; + return this.actor; } public void setActor(String actor) { @@ -38,7 +39,7 @@ public class ActorsFilms { } public List getMovies() { - return movies; + return this.movies; } public void setMovies(List movies) { @@ -47,7 +48,7 @@ public class ActorsFilms { @Override public String toString() { - return "ActorsFilms{" + "actor='" + actor + '\'' + ", movies=" + movies + '}'; + return "ActorsFilms{" + "actor='" + this.actor + '\'' + ", movies=" + this.movies + '}'; } } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java index 45a38e29e..5c0b37374 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.chat; +import java.io.IOException; +import java.net.URL; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -47,16 +59,6 @@ import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Flux; - -import java.io.IOException; -import java.net.URL; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -67,14 +69,14 @@ import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "ZHIPU_AI_API_KEY", matches = ".+") class ZhiPuAiChatModelIT { + private static final Logger logger = LoggerFactory.getLogger(ZhiPuAiChatModelIT.class); + @Autowired protected ChatModel chatModel; @Autowired protected StreamingChatModel streamingChatModel; - private static final Logger logger = LoggerFactory.getLogger(ZhiPuAiChatModelIT.class); - @Value("classpath:/prompts/system-message.st") private Resource systemResource; @@ -82,10 +84,10 @@ class ZhiPuAiChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); // needs fine tuning... evaluateQuestionAndAnswer(request, response, false); @@ -95,10 +97,10 @@ class ZhiPuAiChatModelIT { void streamRoleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - Flux flux = streamingChatModel.stream(prompt); + Flux flux = this.streamingChatModel.stream(prompt); List responses = flux.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -146,7 +148,7 @@ class ZhiPuAiChatModelIT { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @@ -165,14 +167,11 @@ class ZhiPuAiChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Test void beanOutputConverterRecords() { @@ -185,7 +184,7 @@ class ZhiPuAiChatModelIT { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); logger.info("" + actorsFilms); @@ -207,7 +206,7 @@ class ZhiPuAiChatModelIT { Prompt prompt = new Prompt(promptTemplate.createMessage()); String generationTextFromStream = Objects - .requireNonNull(streamingChatModel.stream(prompt).collectList().block()) + .requireNonNull(this.streamingChatModel.stream(prompt).collectList().block()) .stream() .map(ChatResponse::getResults) .flatMap(List::stream) @@ -238,7 +237,7 @@ class ZhiPuAiChatModelIT { .build())) .build(); - ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); @@ -264,7 +263,7 @@ class ZhiPuAiChatModelIT { .build())) .build(); - Flux response = streamingChatModel.stream(new Prompt(messages, promptOptions)); + Flux response = this.streamingChatModel.stream(new Prompt(messages, promptOptions)); String content = Objects.requireNonNull(response.collectList().block()) .stream() @@ -289,7 +288,7 @@ class ZhiPuAiChatModelIT { var userMessage = new UserMessage("Explain what do you see on this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - var response = chatModel + var response = this.chatModel .call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -305,7 +304,7 @@ class ZhiPuAiChatModelIT { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - ChatResponse response = chatModel + ChatResponse response = this.chatModel .call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); @@ -320,7 +319,7 @@ class ZhiPuAiChatModelIT { List.of(new Media(MimeTypeUtils.IMAGE_PNG, new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")))); - Flux response = streamingChatModel.stream(new Prompt(List.of(userMessage), + Flux response = this.streamingChatModel.stream(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withModel(ZhiPuAiApi.ChatModel.GLM_4V.getValue()).build())); String content = Objects.requireNonNull(response.collectList().block()) @@ -335,4 +334,8 @@ class ZhiPuAiChatModelIT { assertThat(content).containsAnyOf("bowl", "basket"); } -} \ No newline at end of file + record ActorsFilmsRecord(String actor, List movies) { + + } + +} diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java index 162a56b4f..69fad0963 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.chat; +import java.util.List; +import java.util.stream.Collectors; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; @@ -35,10 +41,6 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -61,7 +63,7 @@ public class ZhiPuAiChatModelObservationIT { @BeforeEach void beforeEach() { - observationRegistry.clear(); + this.observationRegistry.clear(); } @Test @@ -77,7 +79,7 @@ public class ZhiPuAiChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - ChatResponse chatResponse = chatModel.call(prompt); + ChatResponse chatResponse = this.chatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); @@ -98,7 +100,7 @@ public class ZhiPuAiChatModelObservationIT { Prompt prompt = new Prompt("Why does a raven look like a desk?", options); - Flux chatResponseFlux = chatModel.stream(prompt); + Flux chatResponseFlux = this.chatModel.stream(prompt); List responses = chatResponseFlux.collectList().block(); assertThat(responses).isNotEmpty(); @@ -119,7 +121,7 @@ public class ZhiPuAiChatModelObservationIT { } private void validate(ChatResponseMetadata responseMetadata) { - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/EmbeddingIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/EmbeddingIT.java index 1371ecde5..447546d60 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/EmbeddingIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/EmbeddingIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.embedding; +import java.util.List; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.zhipuai.ZhiPuAiEmbeddingModel; import org.springframework.ai.zhipuai.ZhiPuAiTestConfiguration; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -39,23 +41,23 @@ class EmbeddingIT { @Test void defaultEmbedding() { - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024); - assertThat(embeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } @Test void batchEmbedding() { - assertThat(embeddingModel).isNotNull(); + assertThat(this.embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World", "HI")); + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World", "HI")); assertThat(embeddingResponse.getResults()).hasSize(2); @@ -65,7 +67,7 @@ class EmbeddingIT { assertThat(embeddingResponse.getResults().get(1)).isNotNull(); assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(1024); - assertThat(embeddingModel.dimensions()).isEqualTo(1024); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java index 04f70b010..9ad910595 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.embedding; +import java.util.List; + import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -35,8 +39,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.retry.support.RetryTemplate; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; @@ -64,13 +66,13 @@ public class ZhiPuAiEmbeddingModelObservationIT { EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); - EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); assertThat(embeddingResponse.getResults()).isNotEmpty(); EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); assertThat(responseMetadata).isNotNull(); - TestObservationRegistryAssert.assertThat(observationRegistry) + TestObservationRegistryAssert.assertThat(this.observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) .that() diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/image/ZhiPuAiImageModelIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/image/ZhiPuAiImageModelIT.java index 618bdf487..474bb499c 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/image/ZhiPuAiImageModelIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/image/ZhiPuAiImageModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.zhipuai.image; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.image.Image; import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageOptionsBuilder; @@ -45,7 +47,7 @@ public class ZhiPuAiImageModelIT { ImagePrompt imagePrompt = new ImagePrompt(instructions, options); - ImageResponse imageResponse = imageModel.call(imagePrompt); + ImageResponse imageResponse = this.imageModel.call(imagePrompt); assertThat(imageResponse.getResults()).hasSize(1); diff --git a/mvnw b/mvnw index a16b5431b..657b412d4 100755 --- a/mvnw +++ b/mvnw @@ -1,22 +1,19 @@ #!/bin/sh -# ---------------------------------------------------------------------------- -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at # -# https://www.apache.org/licenses/LICENSE-2.0 +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ---------------------------------------------------------------------------- # ---------------------------------------------------------------------------- # Maven Start Up Batch script diff --git a/pom.xml b/pom.xml index 463665082..872ad4ed7 100644 --- a/pom.xml +++ b/pom.xml @@ -1,3 +1,19 @@ + + 4.0.0 @@ -225,6 +241,11 @@ 3.3.0 0.0.43 + 3.5.0 + true + true + 9.3 + true @@ -243,6 +264,50 @@ + + org.apache.maven.plugins + maven-checkstyle-plugin + ${maven-checkstyle-plugin.version} + + + com.puppycrawl.tools + checkstyle + ${puppycrawl-tools-checkstyle.version} + + + io.spring.javaformat + spring-javaformat-checkstyle + 0.0.43 + + + + + checkstyle-validation + validate + true + + ${disable.checks} + src/checkstyle/checkstyle.xml + src/checkstyle/checkstyle-header.txt + true + + checkstyle.build.directory=${project.build.directory} + checkstyle.suppressions.file=${project.basedir}/src/checkstyle/checkstyle-suppressions.xml + checkstyle.additional.suppressions.file=${project.basedir}/src/checkstyle/checkstyle-suppressions.xml + + true + ${maven-checkstyle-plugin.failsOnError} + + + ${maven-checkstyle-plugin.failOnViolation} + + + + check + + + + org.apache.maven.plugins maven-site-plugin diff --git a/settings.xml b/settings.xml index 890e93070..e86c33787 100644 --- a/settings.xml +++ b/settings.xml @@ -1,3 +1,19 @@ + + + + 4.0.0 diff --git a/spring-ai-core/pom.xml b/spring-ai-core/pom.xml index 8a25edad0..8322b5f66 100644 --- a/spring-ai-core/pom.xml +++ b/spring-ai-core/pom.xml @@ -1,6 +1,23 @@ - + + + 4.0.0 org.springframework.ai @@ -21,6 +38,7 @@ 4.13.1 + false @@ -122,7 +140,7 @@ test - + diff --git a/spring-ai-core/src/main/java/org/springframework/ai/ResourceUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/ResourceUtils.java index a3f54a67c..8e48f220a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/ResourceUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/ResourceUtils.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai; import java.io.IOException; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/aot/AiRuntimeHints.java b/spring-ai-core/src/main/java/org/springframework/ai/aot/AiRuntimeHints.java index 39cd3cd8f..286059bd3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/aot/AiRuntimeHints.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/aot/AiRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.aot; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.aot.hint.TypeReference; -import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider; -import org.springframework.core.type.filter.AnnotationTypeFilter; -import org.springframework.core.type.filter.TypeFilter; +package org.springframework.ai.aot; import java.lang.reflect.Executable; import java.util.Arrays; @@ -31,6 +23,16 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.aot.hint.TypeReference; +import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider; +import org.springframework.core.type.filter.AnnotationTypeFilter; +import org.springframework.core.type.filter.TypeFilter; + /** * Utility methods for creating native runtime hints. See other modules for their * respective native runtime hints. @@ -89,8 +91,9 @@ public abstract class AiRuntimeHints { .stream()// .map(bd -> TypeReference.of(Objects.requireNonNull(bd.getBeanClassName())))// .peek(tr -> { - if (log.isDebugEnabled()) + if (log.isDebugEnabled()) { log.debug("registering [" + tr.getName() + ']'); + } }) .collect(Collectors.toUnmodifiableSet()); } @@ -154,4 +157,4 @@ public abstract class AiRuntimeHints { return jsonTypes; } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/aot/KnuddelsRuntimeHints.java b/spring-ai-core/src/main/java/org/springframework/ai/aot/KnuddelsRuntimeHints.java index fb676484d..c7fd1dcb8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/aot/KnuddelsRuntimeHints.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/aot/KnuddelsRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.aot; import org.springframework.aot.hint.RuntimeHints; @@ -26,4 +27,4 @@ public class KnuddelsRuntimeHints implements RuntimeHintsRegistrar { hints.resources().registerResource(new ClassPathResource("/com/knuddels/jtokkit/cl100k_base.tiktoken")); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java b/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java index 393687bec..2ee283a69 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/aot/SpringAiCoreRuntimeHints.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.aot; -import org.springframework.ai.chat.messages.Message; +import java.lang.reflect.Method; +import java.util.Set; + import org.springframework.ai.chat.messages.AbstractMessage; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; @@ -33,9 +37,6 @@ import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.util.ReflectionUtils; -import java.lang.reflect.Method; -import java.util.Set; - public class SpringAiCoreRuntimeHints implements RuntimeHintsRegistrar { @Override @@ -56,9 +57,10 @@ public class SpringAiCoreRuntimeHints implements RuntimeHintsRegistrar { hints.reflection().registerMethod(getName, ExecutableMode.INVOKE); for (var r : Set.of("antlr4/org/springframework/ai/vectorstore/filter/antlr4/Filters.g4", - "embedding/embedding-model-dimensions.properties")) + "embedding/embedding-model-dimensions.properties")) { hints.resources().registerResource(new ClassPathResource(r)); + } } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscription.java b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscription.java index c6de0ed68..ae89587fb 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscription.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscription.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.audio.transcription; +import java.util.Objects; + import org.springframework.ai.model.ModelResult; import org.springframework.lang.Nullable; -import java.util.Objects; - /** * Represents a response returned by the AI. * @@ -44,7 +45,7 @@ public class AudioTranscription implements ModelResult { @Override public AudioTranscriptionMetadata getMetadata() { - return transcriptionMetadata != null ? transcriptionMetadata : AudioTranscriptionMetadata.NULL; + return this.transcriptionMetadata != null ? this.transcriptionMetadata : AudioTranscriptionMetadata.NULL; } public AudioTranscription withTranscriptionMetadata(@Nullable AudioTranscriptionMetadata transcriptionMetadata) { @@ -54,21 +55,24 @@ public class AudioTranscription implements ModelResult { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof AudioTranscription that)) + } + if (!(o instanceof AudioTranscription that)) { return false; - return Objects.equals(text, that.text) && Objects.equals(transcriptionMetadata, that.transcriptionMetadata); + } + return Objects.equals(this.text, that.text) + && Objects.equals(this.transcriptionMetadata, that.transcriptionMetadata); } @Override public int hashCode() { - return Objects.hash(text, transcriptionMetadata); + return Objects.hash(this.text, this.transcriptionMetadata); } @Override public String toString() { - return "Transcript{" + "text=" + text + ", transcriptionMetadata=" + transcriptionMetadata + '}'; + return "Transcript{" + "text=" + this.text + ", transcriptionMetadata=" + this.transcriptionMetadata + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionMetadata.java index bd064a659..5fc1ea106 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.audio.transcription; import org.springframework.ai.model.ResultMetadata; @@ -32,6 +33,7 @@ public interface AudioTranscriptionMetadata extends ResultMetadata { */ static AudioTranscriptionMetadata create() { return new AudioTranscriptionMetadata() { + }; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionOptions.java index 95bd877e7..7fec8fa97 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.audio.transcription; import org.springframework.ai.model.ModelOptions; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionPrompt.java b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionPrompt.java index 6f5208240..07ca8f644 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionPrompt.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionPrompt.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.audio.transcription; import org.springframework.ai.model.ModelRequest; @@ -57,12 +58,12 @@ public class AudioTranscriptionPrompt implements ModelRequest { @Override public Resource getInstructions() { - return audioResource; + return this.audioResource; } @Override public AudioTranscriptionOptions getOptions() { - return modelOptions; + return this.modelOptions; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponse.java index e1a652355..6bbe17c51 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponse.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.audio.transcription; -import org.springframework.ai.model.ModelResponse; - import java.util.List; +import org.springframework.ai.model.ModelResponse; + /** * @author Michael Lavelle * @author Piotr Olaszewski @@ -42,17 +43,17 @@ public class AudioTranscriptionResponse implements ModelResponse getResults() { - return List.of(transcript); + return List.of(this.transcript); } @Override public AudioTranscriptionResponseMetadata getMetadata() { - return transcriptionResponseMetadata; + return this.transcriptionResponseMetadata; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponseMetadata.java index 66c3fdf89..7c4d49411 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/audio/transcription/AudioTranscriptionResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.audio.transcription; import org.springframework.ai.model.MutableResponseMetadata; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index 6f8d2dad6..fd282fd31 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.client; import java.net.URL; @@ -21,6 +22,9 @@ import java.util.List; import java.util.Map; import java.util.function.Consumer; +import io.micrometer.observation.ObservationRegistry; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; import org.springframework.ai.chat.messages.Message; @@ -36,9 +40,6 @@ import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.io.Resource; import org.springframework.util.MimeType; -import io.micrometer.observation.ObservationRegistry; -import reactor.core.publisher.Flux; - /** * Client to perform stateless requests to an AI Model, using a fluent API. * diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClientCustomizer.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClientCustomizer.java index b1cdf3cc2..bbad16c28 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClientCustomizer.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClientCustomizer.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index af8fc31c1..c0b2d5a8d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -28,12 +28,18 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; + import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.client.observation.ChatClientObservationContext; @@ -62,12 +68,6 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; import org.springframework.util.StringUtils; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; -import reactor.core.publisher.Flux; -import reactor.core.scheduler.Schedulers; - /** * The default implementation of {@link ChatClient} as created by the * {@link Builder#build()} } method. @@ -91,6 +91,30 @@ public class DefaultChatClient implements ChatClient { this.defaultChatClientRequest = defaultChatClientRequest; } + private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest, String formatParam) { + Map advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams()); + if (StringUtils.hasText(formatParam)) { + advisorContext.put("formatParam", formatParam); + } + + return new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, inputRequest.systemText, + inputRequest.chatOptions, inputRequest.media, inputRequest.functionNames, + inputRequest.functionCallbacks, inputRequest.messages, inputRequest.userParams, + inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext, + inputRequest.toolContext); + } + + public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest, + ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) { + + return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), advisedRequest.userText(), + advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(), + advisedRequest.functionCallbacks(), advisedRequest.messages(), advisedRequest.functionNames(), + advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(), + advisedRequest.advisorParams(), observationRegistry, customObservationConvention, + advisedRequest.toolContext()); + } + @Override public ChatClientRequestSpec prompt() { return new DefaultChatClientRequestSpec(this.defaultChatClientRequest); @@ -145,12 +169,12 @@ public class DefaultChatClient implements ChatClient { public static class DefaultPromptUserSpec implements PromptUserSpec { - private String text = ""; - private final Map params = new HashMap<>(); private final List media = new ArrayList<>(); + private String text = ""; + @Override public PromptUserSpec media(Media... media) { this.media.addAll(Arrays.asList(media)); @@ -220,10 +244,10 @@ public class DefaultChatClient implements ChatClient { public static class DefaultPromptSystemSpec implements PromptSystemSpec { - private String text = ""; - private final Map params = new HashMap<>(); + private String text = ""; + @Override public PromptSystemSpec text(String text) { this.text = text; @@ -296,11 +320,11 @@ public class DefaultChatClient implements ChatClient { } public List getAdvisors() { - return advisors; + return this.advisors; } public Map getParams() { - return params; + return this.params; } } @@ -426,12 +450,12 @@ public class DefaultChatClient implements ChatClient { var initialAdvisedRequest = toAdvisedRequest(inputRequest, ""); - // @formatter:off + // @formatter:off // Apply the around advisor chain that terminates with the, last, // model call advisor. - Flux stream = inputRequest.aroundAdvisorChainBuilder.build().nextAroundStream(initialAdvisedRequest); + Flux stream = inputRequest.aroundAdvisorChainBuilder.build().nextAroundStream(initialAdvisedRequest); - return stream + return stream .map(AdvisedResponse::response) .doOnError(observation::error) .doFinally(s -> observation.stop()) @@ -464,12 +488,6 @@ public class DefaultChatClient implements ChatClient { private final ChatModel chatModel; - private String userText = ""; - - private String systemText = ""; - - private ChatOptions chatOptions; - private final List media = new ArrayList<>(); private final List functionNames = new ArrayList<>(); @@ -490,6 +508,90 @@ public class DefaultChatClient implements ChatClient { private final Map toolContext = new HashMap<>(); + private String userText = ""; + + private String systemText = ""; + + private ChatOptions chatOptions; + + /* copy constructor */ + DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) { + this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks, + ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, + ccr.observationRegistry, ccr.customObservationConvention, ccr.toolContext); + } + + public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map userParams, + String systemText, Map systemParams, List functionCallbacks, + List messages, List functionNames, List media, ChatOptions chatOptions, + List advisors, Map advisorParams, ObservationRegistry observationRegistry, + ChatClientObservationConvention customObservationConvention, Map toolContext) { + + this.chatModel = chatModel; + this.chatOptions = chatOptions != null ? chatOptions.copy() + : (chatModel.getDefaultOptions() != null) ? chatModel.getDefaultOptions().copy() : null; + + this.userText = userText; + this.userParams.putAll(userParams); + this.systemText = systemText; + this.systemParams.putAll(systemParams); + + this.functionNames.addAll(functionNames); + this.functionCallbacks.addAll(functionCallbacks); + this.messages.addAll(messages); + this.media.addAll(media); + this.advisors.addAll(advisors); + this.advisorParams.putAll(advisorParams); + this.observationRegistry = observationRegistry; + this.customObservationConvention = customObservationConvention; + this.toolContext.putAll(toolContext); + + // @formatter:off + // At the stack bottom add the non-streaming and streaming model call advisors. + // They play the role of the last advisor in the around advisor chain. + this.advisors.add(new CallAroundAdvisor() { + + @Override + public String getName() { + return CallAroundAdvisor.class.getSimpleName(); + } + + @Override + public int getOrder() { + return Ordered.LOWEST_PRECEDENCE; + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + return new AdvisedResponse(chatModel.call(advisedRequest.toPrompt()), Collections.unmodifiableMap(advisedRequest.adviseContext())); + } + }); + + this.advisors.add(new StreamAroundAdvisor() { + + @Override + public String getName() { + return StreamAroundAdvisor.class.getSimpleName(); + } + + @Override + public int getOrder() { + return Ordered.LOWEST_PRECEDENCE; + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + return chatModel.stream(advisedRequest.toPrompt()) + .map(chatResponse -> new AdvisedResponse(chatResponse, Collections.unmodifiableMap(advisedRequest.adviseContext()))) + .publishOn(Schedulers.boundedElastic()); // TODO add option to disable. + } + }); + // @formatter:on + + this.aroundAdvisorChainBuilder = DefaultAroundAdvisorChain.builder(observationRegistry) + .pushAll(this.advisors); + } + private ObservationRegistry getObservationRegistry() { return this.observationRegistry; } @@ -546,91 +648,13 @@ public class DefaultChatClient implements ChatClient { return this.toolContext; } - /* copy constructor */ - DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) { - this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks, - ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, - ccr.observationRegistry, ccr.customObservationConvention, ccr.toolContext); - } - - public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map userParams, - String systemText, Map systemParams, List functionCallbacks, - List messages, List functionNames, List media, ChatOptions chatOptions, - List advisors, Map advisorParams, ObservationRegistry observationRegistry, - ChatClientObservationConvention customObservationConvention, Map toolContext) { - - this.chatModel = chatModel; - this.chatOptions = chatOptions != null ? chatOptions.copy() - : (chatModel.getDefaultOptions() != null) ? chatModel.getDefaultOptions().copy() : null; - - this.userText = userText; - this.userParams.putAll(userParams); - this.systemText = systemText; - this.systemParams.putAll(systemParams); - - this.functionNames.addAll(functionNames); - this.functionCallbacks.addAll(functionCallbacks); - this.messages.addAll(messages); - this.media.addAll(media); - this.advisors.addAll(advisors); - this.advisorParams.putAll(advisorParams); - this.observationRegistry = observationRegistry; - this.customObservationConvention = customObservationConvention; - this.toolContext.putAll(toolContext); - - // @formatter:off - // At the stack bottom add the non-streaming and streaming model call advisors. - // They play the role of the last advisor in the around advisor chain. - this.advisors.add(new CallAroundAdvisor() { - - @Override - public String getName() { - return CallAroundAdvisor.class.getSimpleName(); - } - - @Override - public int getOrder() { - return Ordered.LOWEST_PRECEDENCE; - } - - @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - return new AdvisedResponse(chatModel.call(advisedRequest.toPrompt()), Collections.unmodifiableMap(advisedRequest.adviseContext())); - } - }); - - this.advisors.add(new StreamAroundAdvisor() { - - @Override - public String getName() { - return StreamAroundAdvisor.class.getSimpleName(); - } - - @Override - public int getOrder() { - return Ordered.LOWEST_PRECEDENCE; - } - - @Override - public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { - return chatModel.stream(advisedRequest.toPrompt()) - .map( chatResponse -> new AdvisedResponse(chatResponse, Collections.unmodifiableMap(advisedRequest.adviseContext()))) - .publishOn(Schedulers.boundedElastic());// TODO add option to disable. - } - }); - // @formatter:on - - this.aroundAdvisorChainBuilder = DefaultAroundAdvisorChain.builder(observationRegistry) - .pushAll(this.advisors); - } - /** * Return a {@code ChatClient2Builder} to create a new {@code ChatClient2} whose * settings are replicated from this {@code ChatClientRequest}. */ public Builder mutate() { DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient - .builder(chatModel, this.observationRegistry, this.customObservationConvention) + .builder(this.chatModel, this.observationRegistry, this.customObservationConvention) .defaultSystem(s -> s.text(this.systemText).params(this.systemParams)) .defaultUser(u -> u.text(this.userText) .params(this.userParams) @@ -827,30 +851,6 @@ public class DefaultChatClient implements ChatClient { } - private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest, String formatParam) { - Map advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams()); - if (StringUtils.hasText(formatParam)) { - advisorContext.put("formatParam", formatParam); - } - - return new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, inputRequest.systemText, - inputRequest.chatOptions, inputRequest.media, inputRequest.functionNames, - inputRequest.functionCallbacks, inputRequest.messages, inputRequest.userParams, - inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext, - inputRequest.toolContext); - } - - public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest, - ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) { - - return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), advisedRequest.userText(), - advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(), - advisedRequest.functionCallbacks(), advisedRequest.messages(), advisedRequest.functionNames(), - advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(), - advisedRequest.advisorParams(), observationRegistry, customObservationConvention, - advisedRequest.toolContext()); - } - // Prompt public static class DefaultCallPromptResponseSpec implements CallPromptResponseSpec { @@ -877,7 +877,7 @@ public class DefaultChatClient implements ChatClient { } private ChatResponse doGetChatResponse(Prompt prompt) { - return chatModel.call(prompt); + return this.chatModel.call(prompt); } } @@ -913,4 +913,4 @@ public class DefaultChatClient implements ChatClient { } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 1053a5d02..6b03d8e2a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,8 @@ import java.util.List; import java.util.Map; import java.util.function.Consumer; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.chat.client.ChatClient.Builder; import org.springframework.ai.chat.client.ChatClient.PromptSystemSpec; import org.springframework.ai.chat.client.ChatClient.PromptUserSpec; @@ -35,8 +37,6 @@ import org.springframework.ai.model.function.FunctionCallback; import org.springframework.core.io.Resource; import org.springframework.util.Assert; -import io.micrometer.observation.ObservationRegistry; - /** * DefaultChatClientBuilder is a builder class for creating a ChatClient. * diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java index bb6e0ae54..fa1d1526a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,18 +21,18 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; -import reactor.core.publisher.Flux; - /** * Advisor called before and after the {@link ChatModel#call(Prompt)} and * {@link ChatModel#stream(Prompt)} methods calls. The {@link ChatClient} maintains a @@ -90,4 +90,4 @@ public interface RequestResponseAdvisor extends CallAroundAdvisor, StreamAroundA .map(chatResponse -> new AdvisedResponse(chatResponse, context)); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java index 069f46aa6..b6ab8fedd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,4 +36,5 @@ public record ResponseEntity(R response, E entity) { public E getEntity() { return this.entity; } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java index 45f4bf8a7..304e53865 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,10 @@ package org.springframework.ai.chat.client.advisor; import java.util.Map; import java.util.function.Function; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; @@ -26,11 +30,6 @@ import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.util.Assert; -import org.stringtemplate.v4.compiler.CodeGenerator.includeExpr_return; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; /** * Abstract class that serves as a base for chat memory advisors. @@ -129,7 +128,7 @@ public abstract class AbstractChatMemoryAdvisor implements CallAroundAdvisor, } public static abstract class AbstractBuilder { - + protected String conversationId = DEFAULT_CHAT_MEMORY_CONVERSATION_ID; protected int chatMemoryRetrieveSize = DEFAULT_CHAT_MEMORY_RESPONSE_SIZE; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java index fe5e3b4ea..ee441d1d6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor; import java.util.ArrayList; @@ -20,6 +21,10 @@ import java.util.Deque; import java.util.List; import java.util.concurrent.ConcurrentLinkedDeque; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; @@ -35,10 +40,6 @@ import org.springframework.core.OrderComparator; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; -import reactor.core.publisher.Flux; - /** * Implementation of the {@link CallAroundAdvisorChain} and * {@link StreamAroundAdvisorChain}. Used by the @@ -71,6 +72,10 @@ public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, Stream this.streamAroundAdvisors = streamAroundAdvisors; } + public static Builder builder(ObservationRegistry observationRegistry) { + return new Builder(observationRegistry); + } + @Override public AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) { @@ -117,17 +122,13 @@ public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, Stream // @formatter:off return Flux.defer(() -> advisor.aroundStream(advisedRequest, this)) - .doOnError(observation::error) - .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on }); } - public static Builder builder(ObservationRegistry observationRegistry) { - return new Builder(observationRegistry); - } - public static class Builder { private final ObservationRegistry observationRegistry; @@ -195,4 +196,4 @@ public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, Stream } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/LastMaxTokenSizeContentPurger.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/LastMaxTokenSizeContentPurger.java index 2c20da81e..8534a2750 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/LastMaxTokenSizeContentPurger.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/LastMaxTokenSizeContentPurger.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,13 @@ package org.springframework.ai.chat.client.advisor; +import java.util.ArrayList; +import java.util.List; + import org.springframework.ai.model.Content; import org.springframework.ai.model.MediaContent; import org.springframework.ai.tokenizer.TokenCountEstimator; -import java.util.ArrayList; -import java.util.List; - /** * Returns a new list of content (e.g list of messages of list of documents) that is a * subset of the input list of contents and complies with the max token size constraint. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java index 0677e28ad..6aa2ca1bc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,8 @@ package org.springframework.ai.chat.client.advisor; import java.util.ArrayList; import java.util.List; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; @@ -29,8 +31,6 @@ import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.MessageAggregator; -import reactor.core.publisher.Flux; - /** * Memory is retrieved added as a collection of messages to the prompt * @@ -52,6 +52,10 @@ public class MessageChatMemoryAdvisor extends AbstractChatMemoryAdvisor { protected Builder(ChatMemory chatMemory) { @@ -124,4 +124,4 @@ public class MessageChatMemoryAdvisor extends AbstractChatMemoryAdvisor { private String systemTextAdvise = DEFAULT_SYSTEM_TEXT_ADVISE; @@ -164,4 +164,4 @@ public class PromptChatMemoryAdvisor extends AbstractChatMemoryAdvisor onFinishReason() { - return (advisedResponse) -> advisedResponse.response() + return advisedResponse -> advisedResponse.response() .getResults() .stream() .filter(result -> result != null && result.getMetadata() != null @@ -262,11 +266,7 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv .isPresent(); } - public static Builder builder(VectorStore vectorStore) { - return new Builder(vectorStore); - } - - public static class Builder { + public static final class Builder { private final VectorStore vectorStore; @@ -312,4 +312,4 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java index 0706320b8..054e3fbf0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.client.advisor; import java.util.List; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; @@ -29,8 +32,6 @@ import org.springframework.ai.chat.model.Generation; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import reactor.core.publisher.Flux; - /** * A {@link CallAroundAdvisor} and {@link StreamAroundAdvisor} that filters out the * response if the user input contains any of the sensitive words. @@ -62,6 +63,10 @@ public class SafeGuardAdvisor implements CallAroundAdvisor, StreamAroundAdvisor this.order = order; } + public static Builder builder() { + return new Builder(); + } + public String getName() { return this.getClass().getSimpleName(); } @@ -82,7 +87,7 @@ public class SafeGuardAdvisor implements CallAroundAdvisor, StreamAroundAdvisor public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { if (!CollectionUtils.isEmpty(this.sensitiveWords) - && sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { + && this.sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { return Flux.just(createFailureResponse(advisedRequest)); } @@ -100,11 +105,7 @@ public class SafeGuardAdvisor implements CallAroundAdvisor, StreamAroundAdvisor return this.order; } - public static Builder builder() { - return new Builder(); - } - - public static class Builder { + public static final class Builder { private List sensitiveWords; @@ -136,4 +137,4 @@ public class SafeGuardAdvisor implements CallAroundAdvisor, StreamAroundAdvisor } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java index d23388682..889753108 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,24 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.client.advisor; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.model.ModelOptionsUtils; -import reactor.core.publisher.Flux; - /** * A simple logger advisor that logs the request and response messages. * @@ -38,22 +39,19 @@ import reactor.core.publisher.Flux; */ public class SimpleLoggerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + public static final Function DEFAULT_REQUEST_TO_STRING = request -> request.toString(); + + public static final Function DEFAULT_RESPONSE_TO_STRING = response -> ModelOptionsUtils + .toJsonString(response); + private static final Logger logger = LoggerFactory.getLogger(SimpleLoggerAdvisor.class); - private int order; - - public static final Function DEFAULT_REQUEST_TO_STRING = (request) -> { - return request.toString(); - }; - - public static final Function DEFAULT_RESPONSE_TO_STRING = (response) -> { - return ModelOptionsUtils.toJsonString(response); - }; - private final Function requestToString; private final Function responseToString; + private int order; + public SimpleLoggerAdvisor() { this(DEFAULT_REQUEST_TO_STRING, DEFAULT_RESPONSE_TO_STRING, 0); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java index c844f5dfe..02acf5212 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,8 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; @@ -36,8 +38,6 @@ import org.springframework.ai.model.Content; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; -import reactor.core.publisher.Flux; - /** * Memory is retrieved from a VectorStore added into the prompt's system text. * @@ -99,6 +99,10 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor { private String systemTextAdvise = DEFAULT_SYSTEM_TEXT_ADVISE; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java index 032de8b63..afca77476 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,7 +24,6 @@ import java.util.List; import java.util.Map; import java.util.function.Function; -import org.springframework.ai.model.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -32,6 +31,7 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.util.CollectionUtils; @@ -63,12 +63,6 @@ public record AdvisedRequest(ChatModel chatModel, String userText, String system Map userParams, Map systemParams, List advisors, Map advisorParams, Map adviseContext, Map toolContext) { - public AdvisedRequest updateContext(Function, Map> contextTransform) { - return from(this) - .withAdviseContext(Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext)))) - .build(); - } - public static Builder from(AdvisedRequest from) { Builder builder = new Builder(); builder.chatModel = from.chatModel; @@ -93,8 +87,60 @@ public record AdvisedRequest(ChatModel chatModel, String userText, String system return new Builder(); } + public AdvisedRequest updateContext(Function, Map> contextTransform) { + return from(this) + .withAdviseContext(Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext)))) + .build(); + } + + public Prompt toPrompt() { + + var messages = new ArrayList(this.messages()); + + String processedSystemText = this.systemText(); + if (StringUtils.hasText(processedSystemText)) { + if (!CollectionUtils.isEmpty(this.systemParams())) { + processedSystemText = new PromptTemplate(processedSystemText, this.systemParams()).render(); + } + messages.add(new SystemMessage(processedSystemText)); + } + + String formatParam = (String) this.adviseContext().get("formatParam"); + + var processedUserText = StringUtils.hasText(formatParam) + ? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText(); + + if (StringUtils.hasText(processedUserText)) { + + Map userParams = new HashMap<>(this.userParams()); + if (StringUtils.hasText(formatParam)) { + userParams.put("spring_ai_soc_format", formatParam); + } + if (!CollectionUtils.isEmpty(userParams)) { + processedUserText = new PromptTemplate(processedUserText, userParams).render(); + } + messages.add(new UserMessage(processedUserText, this.media())); + } + + if (this.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) { + if (!this.functionNames().isEmpty()) { + functionCallingOptions.setFunctions(new HashSet<>(this.functionNames())); + } + if (!this.functionCallbacks().isEmpty()) { + functionCallingOptions.setFunctionCallbacks(this.functionCallbacks()); + } + if (!CollectionUtils.isEmpty(this.toolContext())) { + functionCallingOptions.setToolContext(this.toolContext()); + } + } + + return new Prompt(messages, this.chatOptions()); + } + public static class Builder { + public Map toolContext = Map.of(); + private ChatModel chatModel; private String userText = ""; @@ -121,8 +167,6 @@ public record AdvisedRequest(ChatModel chatModel, String userText, String system private Map adviseContext = Map.of(); - public Map toolContext = Map.of(); - public Builder withChatModel(ChatModel chatModel) { this.chatModel = chatModel; return this; @@ -194,55 +238,11 @@ public record AdvisedRequest(ChatModel chatModel, String userText, String system } public AdvisedRequest build() { - return new AdvisedRequest(chatModel, this.userText, this.systemText, this.chatOptions, this.media, + return new AdvisedRequest(this.chatModel, this.userText, this.systemText, this.chatOptions, this.media, this.functionNames, this.functionCallbacks, this.messages, this.userParams, this.systemParams, this.advisors, this.advisorParams, this.adviseContext, this.toolContext); } } - public Prompt toPrompt() { - - var messages = new ArrayList(this.messages()); - - String processedSystemText = this.systemText(); - if (StringUtils.hasText(processedSystemText)) { - if (!CollectionUtils.isEmpty(this.systemParams())) { - processedSystemText = new PromptTemplate(processedSystemText, this.systemParams()).render(); - } - messages.add(new SystemMessage(processedSystemText)); - } - - String formatParam = (String) this.adviseContext().get("formatParam"); - - var processedUserText = StringUtils.hasText(formatParam) - ? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText(); - - if (StringUtils.hasText(processedUserText)) { - - Map userParams = new HashMap<>(this.userParams()); - if (StringUtils.hasText(formatParam)) { - userParams.put("spring_ai_soc_format", formatParam); - } - if (!CollectionUtils.isEmpty(userParams)) { - processedUserText = new PromptTemplate(processedUserText, userParams).render(); - } - messages.add(new UserMessage(processedUserText, this.media())); - } - - if (this.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) { - if (!this.functionNames().isEmpty()) { - functionCallingOptions.setFunctions(new HashSet<>(this.functionNames())); - } - if (!this.functionCallbacks().isEmpty()) { - functionCallingOptions.setFunctionCallbacks(this.functionCallbacks()); - } - if (!CollectionUtils.isEmpty(this.toolContext())) { - functionCallingOptions.setToolContext(this.toolContext()); - } - } - - return new Prompt(messages, this.chatOptions()); - } - -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java index 8c81740cf..a03247fd6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.api; import java.util.Collections; @@ -29,15 +30,15 @@ import org.springframework.util.Assert; */ public record AdvisedResponse(ChatResponse response, Map adviseContext) { - public AdvisedResponse updateContext(Function, Map> contextTransform) { - return new AdvisedResponse(response, - Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(adviseContext)))); - } - public static Builder builder() { return new Builder(); } + public AdvisedResponse updateContext(Function, Map> contextTransform) { + return new AdvisedResponse(this.response, + Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext)))); + } + public static class Builder { private ChatResponse response; @@ -66,8 +67,9 @@ public record AdvisedResponse(ChatResponse response, Map adviseC public AdvisedResponse build() { Assert.notNull(this.adviseContext, "the adviseContext must be non-null"); - return new AdvisedResponse(response, adviseContext); + return new AdvisedResponse(this.response, this.adviseContext); } } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java index c7a931b85..c03eb507e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.api; import org.springframework.core.Ordered; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java index 57d19df60..05369aace 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.api; /** @@ -31,4 +32,4 @@ public interface CallAroundAdvisor extends Advisor { */ AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java index 9a51a01fa..9158a7212 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.api; /** @@ -34,4 +35,4 @@ public interface CallAroundAdvisorChain { */ AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java index eeb65aa66..56ff624f0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.api; import reactor.core.publisher.Flux; @@ -32,4 +33,4 @@ public interface StreamAroundAdvisor extends Advisor { */ Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java index 43837fd43..175ae9e71 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.api; import reactor.core.publisher.Flux; @@ -36,4 +37,4 @@ public interface StreamAroundAdvisorChain { */ Flux nextAroundStream(AdvisedRequest advisedRequest); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java index 9effbf29b..394e89efd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java @@ -1,30 +1,31 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.observation; import java.util.Map; +import io.micrometer.observation.Observation; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.lang.Nullable; import org.springframework.util.Assert; -import io.micrometer.observation.Observation; - /** * Context used to store metadata for chat client advisors. * @@ -34,16 +35,15 @@ import io.micrometer.observation.Observation; */ public class AdvisorObservationContext extends Observation.Context { - public enum Type { - - BEFORE, AFTER, AROUND - - } - private final String advisorName; private final Type advisorType; + /** + * The order of the advisor in the advisor chain. + */ + private final int order; + /** * The {@link AdvisedRequest} data to be advised. Represents the row * {@link ChatClient.ChatClientRequestSpec} data before sealed into a {@link Prompt}. @@ -65,11 +65,6 @@ public class AdvisorObservationContext extends Observation.Context { @Nullable private Map advisorResponseContext; - /** - * The order of the advisor in the advisor chain. - */ - private final int order; - public AdvisorObservationContext(String advisorName, Type advisorType, @Nullable AdvisedRequest advisorRequest, @Nullable Map advisorRequestContext, @Nullable Map advisorResponseContext, int order) { @@ -84,6 +79,10 @@ public class AdvisorObservationContext extends Observation.Context { this.order = order; } + public static Builder builder() { + return new Builder(); + } + public String getAdvisorName() { return this.advisorName; } @@ -123,8 +122,10 @@ public class AdvisorObservationContext extends Observation.Context { return this.order; } - public static Builder builder() { - return new Builder(); + public enum Type { + + BEFORE, AFTER, AROUND + } public static class Builder { @@ -172,10 +173,10 @@ public class AdvisorObservationContext extends Observation.Context { } public AdvisorObservationContext build() { - return new AdvisorObservationContext(advisorName, advisorType, advisorRequest, advisorRequestContext, - advisorResponseContext, order); + return new AdvisorObservationContext(this.advisorName, this.advisorType, this.advisorRequest, + this.advisorRequestContext, this.advisorResponseContext, this.order); } } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationConvention.java index 7726c65e9..10e5212c1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationConvention.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.client.advisor.observation; import io.micrometer.observation.Observation; @@ -32,4 +33,4 @@ public interface AdvisorObservationConvention extends ObservationConvention toolCalls; public AssistantMessage(String content) { @@ -63,24 +61,31 @@ public class AssistantMessage extends AbstractMessage { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof AssistantMessage that)) + } + if (!(o instanceof AssistantMessage that)) { return false; - if (!super.equals(o)) + } + if (!super.equals(o)) { return false; - return Objects.equals(toolCalls, that.toolCalls); + } + return Objects.equals(this.toolCalls, that.toolCalls); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), toolCalls); + return Objects.hash(super.hashCode(), this.toolCalls); } @Override public String toString() { - return "AssistantMessage [messageType=" + messageType + ", toolCalls=" + toolCalls + ", textContent=" - + textContent + ", metadata=" + metadata + "]"; + return "AssistantMessage [messageType=" + this.messageType + ", toolCalls=" + this.toolCalls + ", textContent=" + + this.textContent + ", metadata=" + this.metadata + "]"; + } + + public record ToolCall(String id, String type, String name, String arguments) { + } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java index bdd11e9a6..089b88b8a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.messages; import org.springframework.ai.model.Content; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java index 7603ab39f..876b004ee 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.messages; /** @@ -56,10 +57,6 @@ public enum MessageType { this.value = value; } - public String getValue() { - return value; - } - public static MessageType fromValue(String value) { for (MessageType messageType : MessageType.values()) { if (messageType.getValue().equals(value)) { @@ -69,4 +66,8 @@ public enum MessageType { throw new IllegalArgumentException("Invalid MessageType value: " + value); } + public String getValue() { + return this.value; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java index ddcff7966..e673de98a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.messages; import java.util.Map; @@ -44,24 +45,27 @@ public class SystemMessage extends AbstractMessage { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof SystemMessage that)) + } + if (!(o instanceof SystemMessage that)) { return false; - if (!super.equals(o)) + } + if (!super.equals(o)) { return false; - return Objects.equals(textContent, that.textContent); + } + return Objects.equals(this.textContent, that.textContent); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), textContent); + return Objects.hash(super.hashCode(), this.textContent); } @Override public String toString() { - return "SystemMessage{" + "textContent='" + textContent + '\'' + ", messageType=" + messageType + ", metadata=" - + metadata + '}'; + return "SystemMessage{" + "textContent='" + this.textContent + '\'' + ", messageType=" + this.messageType + + ", metadata=" + this.metadata + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java index 42f91f9df..47da25218 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.messages; import java.util.List; @@ -28,9 +29,6 @@ import java.util.Objects; */ public class ToolResponseMessage extends AbstractMessage { - public record ToolResponse(String id, String name, String responseData) { - }; - protected final List responses; public ToolResponseMessage(List responses) { @@ -48,24 +46,31 @@ public class ToolResponseMessage extends AbstractMessage { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ToolResponseMessage that)) + } + if (!(o instanceof ToolResponseMessage that)) { return false; - if (!super.equals(o)) + } + if (!super.equals(o)) { return false; - return Objects.equals(responses, that.responses); + } + return Objects.equals(this.responses, that.responses); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), responses); + return Objects.hash(super.hashCode(), this.responses); } @Override public String toString() { - return "ToolResponseMessage{" + "responses=" + responses + ", messageType=" + messageType + ", metadata=" - + metadata + '}'; + return "ToolResponseMessage{" + "responses=" + this.responses + ", messageType=" + this.messageType + + ", metadata=" + this.metadata + '}'; + } + + public record ToolResponse(String id, String name, String responseData) { + } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index 53c324257..5a7e7db57 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.messages; import java.util.ArrayList; @@ -69,8 +70,8 @@ public class UserMessage extends AbstractMessage implements MediaContent { @Override public String toString() { - return "UserMessage{" + "content='" + getContent() + '\'' + ", properties=" + metadata + ", messageType=" - + messageType + '}'; + return "UserMessage{" + "content='" + getContent() + '\'' + ", properties=" + this.metadata + ", messageType=" + + this.messageType + '}'; } @Override diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java index 744c3fdab..77728e276 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; import org.springframework.ai.model.ResultMetadata; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java index f58bc5d24..20126472b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; import java.util.Map; @@ -20,6 +21,7 @@ import java.util.Objects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.model.AbstractResponseMetadata; import org.springframework.ai.model.ResponseMetadata; @@ -36,7 +38,8 @@ public class ChatResponseMetadata extends AbstractResponseMetadata implements Re private final static Logger logger = LoggerFactory.getLogger(ChatResponseMetadata.class); private String id = ""; // Set to blank to preserve backward compat with previous - // interface default methods + + // interface default methods private String model = ""; @@ -46,6 +49,10 @@ public class ChatResponseMetadata extends AbstractResponseMetadata implements Re private PromptMetadata promptMetadata = PromptMetadata.empty(); + public static Builder builder() { + return new Builder(); + } + /** * A unique identifier for the chat completion operation. * @return unique operation identifier. @@ -88,6 +95,29 @@ public class ChatResponseMetadata extends AbstractResponseMetadata implements Re return this.promptMetadata; } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ChatResponseMetadata that)) { + return false; + } + return Objects.equals(this.id, that.id) && Objects.equals(this.model, that.model) + && Objects.equals(this.rateLimit, that.rateLimit) && Objects.equals(this.usage, that.usage) + && Objects.equals(this.promptMetadata, that.promptMetadata); + } + + @Override + public int hashCode() { + return Objects.hash(this.id, this.model, this.rateLimit, this.usage, this.promptMetadata); + } + + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getId(), getUsage(), getRateLimit()); + } + public static class Builder { private final ChatResponseMetadata chatResponseMetadata; @@ -145,29 +175,4 @@ public class ChatResponseMetadata extends AbstractResponseMetadata implements Re } - public static Builder builder() { - return new Builder(); - } - - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (!(o instanceof ChatResponseMetadata that)) - return false; - return Objects.equals(this.id, that.id) && Objects.equals(this.model, that.model) - && Objects.equals(this.rateLimit, that.rateLimit) && Objects.equals(this.usage, that.usage) - && Objects.equals(this.promptMetadata, that.promptMetadata); - } - - @Override - public int hashCode() { - return Objects.hash(this.id, this.model, this.rateLimit, this.usage, this.promptMetadata); - } - - @Override - public String toString() { - return AI_METADATA_STRING.formatted(getId(), getUsage(), getRateLimit()); - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java index 5f5ee9c51..a9fa52a30 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; +import java.util.Objects; + import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import java.util.Objects; - /** * Default implementation of the {@link Usage} interface. * @@ -66,19 +67,19 @@ public class DefaultUsage implements Usage { @Override @JsonProperty("promptTokens") public Long getPromptTokens() { - return promptTokens; + return this.promptTokens; } @Override @JsonProperty("generationTokens") public Long getGenerationTokens() { - return generationTokens; + return this.generationTokens; } @Override @JsonProperty("totalTokens") public Long getTotalTokens() { - return totalTokens; + return this.totalTokens; } private Long calculateTotalTokens(Long promptTokens, Long generationTokens) { @@ -87,25 +88,27 @@ public class DefaultUsage implements Usage { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || getClass() != o.getClass()) { return false; + } DefaultUsage that = (DefaultUsage) o; - return Objects.equals(promptTokens, that.promptTokens) - && Objects.equals(generationTokens, that.generationTokens) - && Objects.equals(totalTokens, that.totalTokens); + return Objects.equals(this.promptTokens, that.promptTokens) + && Objects.equals(this.generationTokens, that.generationTokens) + && Objects.equals(this.totalTokens, that.totalTokens); } @Override public int hashCode() { - return Objects.hash(promptTokens, generationTokens, totalTokens); + return Objects.hash(this.promptTokens, this.generationTokens, this.totalTokens); } @Override public String toString() { - return "DefaultUsage{" + "promptTokens=" + promptTokens + ", generationTokens=" + generationTokens - + ", totalTokens=" + totalTokens + '}'; + return "DefaultUsage{" + "promptTokens=" + this.promptTokens + ", generationTokens=" + this.generationTokens + + ", totalTokens=" + this.totalTokens + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyRateLimit.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyRateLimit.java index 80eb2462c..0506dbbf2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyRateLimit.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyRateLimit.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; import java.time.Duration; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyUsage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyUsage.java index 48cf590e7..b9cdaf872 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyUsage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyUsage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/PromptMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/PromptMetadata.java index bbc32a791..c78e61c27 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/PromptMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/PromptMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; import java.util.Arrays; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/RateLimit.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/RateLimit.java index 689380219..9cedeecd7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/RateLimit.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/RateLimit.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; import java.time.Duration; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java index d2dffc808..887bfbaa4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java index e32b4da6d..38ba886f6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.model; import java.util.ArrayList; -import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatModel.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatModel.java index 5f2687bca..e72a1c23e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.model; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.Prompt; +package org.springframework.ai.chat.model; import java.util.Arrays; @@ -24,6 +22,8 @@ import reactor.core.publisher.Flux; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Model; public interface ChatModel extends Model, StreamingChatModel { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java index 58d75e218..657a4ef9f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.model; import java.util.List; @@ -20,9 +21,9 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.model.ModelResponse; import org.springframework.util.CollectionUtils; -import org.springframework.ai.chat.metadata.ChatResponseMetadata; /** * The chat completion (e.g. generation) response returned by an AI provider. @@ -57,6 +58,10 @@ public class ChatResponse implements ModelResponse { this.generations = List.copyOf(generations); } + public static ChatResponse.Builder builder() { + return new ChatResponse.Builder(); + } + /** * The {@link List} of {@link Generation generated outputs}. *

@@ -91,29 +96,27 @@ public class ChatResponse implements ModelResponse { @Override public String toString() { - return "ChatResponse [metadata=" + chatResponseMetadata + ", generations=" + generations + "]"; + return "ChatResponse [metadata=" + this.chatResponseMetadata + ", generations=" + this.generations + "]"; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ChatResponse that)) + } + if (!(o instanceof ChatResponse that)) { return false; - return Objects.equals(chatResponseMetadata, that.chatResponseMetadata) - && Objects.equals(generations, that.generations); + } + return Objects.equals(this.chatResponseMetadata, that.chatResponseMetadata) + && Objects.equals(this.generations, that.generations); } @Override public int hashCode() { - return Objects.hash(chatResponseMetadata, generations); + return Objects.hash(this.chatResponseMetadata, this.generations); } - public static ChatResponse.Builder builder() { - return new ChatResponse.Builder(); - } - - public static class Builder { + public static final class Builder { private List generations; @@ -149,7 +152,7 @@ public class ChatResponse implements ModelResponse { } public ChatResponse build() { - return new ChatResponse(generations, chatResponseMetadataBuilder.build()); + return new ChatResponse(this.generations, this.chatResponseMetadataBuilder.build()); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/Generation.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/Generation.java index 9d98d15cd..210935ead 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/Generation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/Generation.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.model; import java.util.Map; @@ -82,10 +83,12 @@ public class Generation implements ModelResult { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof Generation that)) + } + if (!(o instanceof Generation that)) { return false; + } return Objects.equals(this.assistantMessage, that.assistantMessage) && Objects.equals(this.chatGenerationMetadata, that.chatGenerationMetadata); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java index 1c6bfc702..6940d193e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,8 @@ import java.util.function.Consumer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; @@ -34,8 +36,6 @@ import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.chat.metadata.Usage; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; - /** * Helper that for streaming chat responses, aggregate the chat response messages into a * single AssistantMessage. Job is performed in parallel to the chat response processing. @@ -167,9 +167,7 @@ public class MessageAggregator { metadataPromptMetadataRef.set(PromptMetadata.empty()); metadataRateLimitRef.set(new EmptyRateLimit()); - }).doOnError(e -> { - logger.error("Aggregation Error", e); - }); + }).doOnError(e -> logger.error("Aggregation Error", e)); } public record DefaultUsage(long promptTokens, long generationTokens, long totalTokens) implements Usage { @@ -188,6 +186,7 @@ public class MessageAggregator { public Long getTotalTokens() { return totalTokens(); } + } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/StreamingChatModel.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/StreamingChatModel.java index 2eab40e45..9105add8c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/StreamingChatModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/StreamingChatModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.model; import java.util.Arrays; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ToolContext.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ToolContext.java index 2d49e1ebc..5ba3a60eb 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ToolContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ToolContext.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.model; import java.util.Collections; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilter.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilter.java index 72cb7d825..c68227bb5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilter.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationFilter; + import org.springframework.ai.observation.tracing.TracingHelper; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandler.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandler.java index 9404c2556..59612fa26 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandler.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandler.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; import io.micrometer.observation.Observation; @@ -21,6 +22,7 @@ import io.micrometer.tracing.handler.TracingObservationHandler; import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.common.Attributes; import io.opentelemetry.api.trace.Span; + import org.springframework.ai.observation.conventions.AiObservationAttributes; import org.springframework.ai.observation.conventions.AiObservationEventNames; import org.springframework.ai.observation.tracing.TracingHelper; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandler.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandler.java index 1604e0451..9b19d4199 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandler.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandler.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationHandler; + import org.springframework.ai.model.observation.ModelUsageMetricsGenerator; /** @@ -38,7 +40,8 @@ public class ChatModelMeterObservationHandler implements ObservationHandler prompt(ChatModelObservationContext context) { if (CollectionUtils.isEmpty(context.getRequest().getInstructions())) { return List.of(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java index eb20f161a..525f5fab3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; import org.springframework.ai.chat.model.ChatResponse; @@ -40,15 +41,15 @@ public class ChatModelObservationContext extends ModelObservationContext stop) { + this.options.setStopSequences(stop); + return this; + } + + public ChatOptionsBuilder withTemperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public ChatOptionsBuilder withTopK(Integer topK) { + this.options.setTopK(topK); + return this; + } + + public ChatOptionsBuilder withTopP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public ChatOptions build() { + return this.options; + } private static class DefaultChatOptions implements ChatOptions { @@ -39,7 +93,7 @@ public class ChatOptionsBuilder { @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -48,7 +102,7 @@ public class ChatOptionsBuilder { @Override public Double getFrequencyPenalty() { - return frequencyPenalty; + return this.frequencyPenalty; } public void setFrequencyPenalty(Double frequencyPenalty) { @@ -57,7 +111,7 @@ public class ChatOptionsBuilder { @Override public Integer getMaxTokens() { - return maxTokens; + return this.maxTokens; } public void setMaxTokens(Integer maxTokens) { @@ -66,7 +120,7 @@ public class ChatOptionsBuilder { @Override public Double getPresencePenalty() { - return presencePenalty; + return this.presencePenalty; } public void setPresencePenalty(Double presencePenalty) { @@ -75,7 +129,7 @@ public class ChatOptionsBuilder { @Override public List getStopSequences() { - return stopSequences; + return this.stopSequences; } public void setStopSequences(List stopSequences) { @@ -84,7 +138,7 @@ public class ChatOptionsBuilder { @Override public Double getTemperature() { - return temperature; + return this.temperature; } public void setTemperature(Double temperature) { @@ -93,7 +147,7 @@ public class ChatOptionsBuilder { @Override public Integer getTopK() { - return topK; + return this.topK; } public void setTopK(Integer topK) { @@ -102,7 +156,7 @@ public class ChatOptionsBuilder { @Override public Double getTopP() { - return topP; + return this.topP; } public void setTopP(Double topP) { @@ -124,57 +178,4 @@ public class ChatOptionsBuilder { } - private final DefaultChatOptions options = new DefaultChatOptions(); - - private ChatOptionsBuilder() { - } - - public static ChatOptionsBuilder builder() { - return new ChatOptionsBuilder(); - } - - public ChatOptionsBuilder withModel(String model) { - options.setModel(model); - return this; - } - - public ChatOptionsBuilder withFrequencyPenalty(Double frequencyPenalty) { - options.setFrequencyPenalty(frequencyPenalty); - return this; - } - - public ChatOptionsBuilder withMaxTokens(Integer maxTokens) { - options.setMaxTokens(maxTokens); - return this; - } - - public ChatOptionsBuilder withPresencePenalty(Double presencePenalty) { - options.setPresencePenalty(presencePenalty); - return this; - } - - public ChatOptionsBuilder withStopSequences(List stop) { - options.setStopSequences(stop); - return this; - } - - public ChatOptionsBuilder withTemperature(Double temperature) { - options.setTemperature(temperature); - return this; - } - - public ChatOptionsBuilder withTopK(Integer topK) { - options.setTopK(topK); - return this; - } - - public ChatOptionsBuilder withTopP(Double topP) { - options.setTopP(topP); - return this; - } - - public ChatOptions build() { - return options; - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatPromptTemplate.java index 4db4aee57..918316187 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatPromptTemplate.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.prompt; -import org.springframework.ai.chat.messages.Message; +package org.springframework.ai.chat.prompt; import java.util.ArrayList; import java.util.List; import java.util.Map; +import org.springframework.ai.chat.messages.Message; + /** * A PromptTemplate that lets you specify the role as a string should the current * implementations and their roles not suffice for your needs. @@ -36,7 +37,7 @@ public class ChatPromptTemplate implements PromptTemplateActions, PromptTemplate @Override public String render() { StringBuilder sb = new StringBuilder(); - for (PromptTemplate promptTemplate : promptTemplates) { + for (PromptTemplate promptTemplate : this.promptTemplates) { sb.append(promptTemplate.render()); } return sb.toString(); @@ -45,7 +46,7 @@ public class ChatPromptTemplate implements PromptTemplateActions, PromptTemplate @Override public String render(Map model) { StringBuilder sb = new StringBuilder(); - for (PromptTemplate promptTemplate : promptTemplates) { + for (PromptTemplate promptTemplate : this.promptTemplates) { sb.append(promptTemplate.render(model)); } return sb.toString(); @@ -54,7 +55,7 @@ public class ChatPromptTemplate implements PromptTemplateActions, PromptTemplate @Override public List createMessages() { List messages = new ArrayList<>(); - for (PromptTemplate promptTemplate : promptTemplates) { + for (PromptTemplate promptTemplate : this.promptTemplates) { messages.add(promptTemplate.createMessage()); } return messages; @@ -63,7 +64,7 @@ public class ChatPromptTemplate implements PromptTemplateActions, PromptTemplate @Override public List createMessages(Map model) { List messages = new ArrayList<>(); - for (PromptTemplate promptTemplate : promptTemplates) { + for (PromptTemplate promptTemplate : this.promptTemplates) { messages.add(promptTemplate.createMessage(model)); } return messages; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/FunctionPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/FunctionPromptTemplate.java index 913c18c85..3a8b36889 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/FunctionPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/FunctionPromptTemplate.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.prompt; public class FunctionPromptTemplate extends PromptTemplate { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index 3d36b1bfb..743314f8d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.prompt; import java.util.ArrayList; @@ -91,10 +92,12 @@ public class Prompt implements ModelRequest> { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof Prompt prompt)) + } + if (!(o instanceof Prompt prompt)) { return false; + } return Objects.equals(this.messages, prompt.messages) && Objects.equals(this.chatOptions, prompt.chatOptions); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java index 162b0d59f..852089e2b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.prompt; import java.io.IOException; @@ -30,22 +31,22 @@ import org.antlr.runtime.TokenStream; import org.stringtemplate.v4.ST; import org.stringtemplate.v4.compiler.STLexer; -import org.springframework.ai.model.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.model.Media; import org.springframework.core.io.Resource; import org.springframework.util.StreamUtils; public class PromptTemplate implements PromptTemplateActions, PromptTemplateMessageActions { - private ST st; - - private Map dynamicModel = new HashMap<>(); - protected String template; protected TemplateFormat templateFormat = TemplateFormat.ST; + private ST st; + + private Map dynamicModel = new HashMap<>(); + public PromptTemplate(Resource resource) { try (InputStream inputStream = resource.getInputStream()) { this.template = StreamUtils.copyToString(inputStream, Charset.defaultCharset()); @@ -122,7 +123,7 @@ public class PromptTemplate implements PromptTemplateActions, PromptTemplateMess @Override public String render() { validate(this.dynamicModel); - return st.render(); + return this.st.render(); } @Override diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateActions.java index 872d35f4c..76d5ef017 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateActions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.prompt; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateChatActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateChatActions.java index dd4424d07..120cd87aa 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateChatActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateChatActions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.prompt; -import org.springframework.ai.chat.messages.Message; +package org.springframework.ai.chat.prompt; import java.util.List; import java.util.Map; +import org.springframework.ai.chat.messages.Message; + public interface PromptTemplateChatActions { List createMessages(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java index 8edcd36da..c87507b3c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.prompt; -import org.springframework.ai.model.Media; -import org.springframework.ai.chat.messages.Message; +package org.springframework.ai.chat.prompt; import java.util.List; import java.util.Map; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.model.Media; + public interface PromptTemplateMessageActions { Message createMessage(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateStringActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateStringActions.java index bd81ed4dd..81be88b93 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateStringActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateStringActions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.prompt; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/SystemPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/SystemPromptTemplate.java index 8ac1aa85e..18b1629db 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/SystemPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/SystemPromptTemplate.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.prompt; +import java.util.Map; + import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.core.io.Resource; -import java.util.Map; - public class SystemPromptTemplate extends PromptTemplate { public SystemPromptTemplate(String template) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/TemplateFormat.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/TemplateFormat.java index fe13fcf7d..a174300e9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/TemplateFormat.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/TemplateFormat.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.prompt; public enum TemplateFormat { @@ -25,10 +26,6 @@ public enum TemplateFormat { this.value = value; } - public String getValue() { - return value; - } - public static TemplateFormat fromValue(String value) { for (TemplateFormat templateFormat : TemplateFormat.values()) { if (templateFormat.getValue().equals(value)) { @@ -38,4 +35,8 @@ public enum TemplateFormat { throw new IllegalArgumentException("Invalid TemplateFormat value: " + value); } + public String getValue() { + return this.value; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractConversionServiceOutputConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractConversionServiceOutputConverter.java index 6f209fa86..b4e4d0868 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractConversionServiceOutputConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractConversionServiceOutputConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; import org.springframework.core.convert.support.DefaultConversionService; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractMessageOutputConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractMessageOutputConverter.java index 05077025c..7a22dfc55 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractMessageOutputConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/AbstractMessageOutputConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; import org.springframework.messaging.converter.MessageConverter; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java index 424d291de..68b0a3582 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,22 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.converter; -import static com.github.victools.jsonschema.generator.OptionPreset.PLAIN_JSON; -import static com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12; +package org.springframework.ai.converter; import java.lang.reflect.Type; import java.util.Objects; -import com.fasterxml.jackson.databind.json.JsonMapper; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.springframework.ai.util.JacksonUtils; -import org.springframework.core.ParameterizedTypeReference; -import org.springframework.lang.NonNull; - import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.core.util.DefaultIndenter; @@ -37,12 +27,19 @@ import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectWriter; +import com.fasterxml.jackson.databind.json.JsonMapper; import com.github.victools.jsonschema.generator.Option; import com.github.victools.jsonschema.generator.SchemaGenerator; import com.github.victools.jsonschema.generator.SchemaGeneratorConfig; import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder; import com.github.victools.jsonschema.module.jackson.JacksonModule; import com.github.victools.jsonschema.module.jackson.JacksonOption; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.util.JacksonUtils; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.lang.NonNull; /** * An implementation of {@link StructuredOutputConverter} that transforms the LLM output @@ -62,9 +59,6 @@ public class BeanOutputConverter implements StructuredOutputConverter { private final Logger logger = LoggerFactory.getLogger(BeanOutputConverter.class); - /** Holds the generated JSON schema for the target type. */ - private String jsonSchema; - /** * The target class type reference to which the output will be converted. */ @@ -73,6 +67,9 @@ public class BeanOutputConverter implements StructuredOutputConverter { /** The object mapper used for deserialization and other JSON operations. */ private final ObjectMapper objectMapper; + /** Holds the generated JSON schema for the target type. */ + private String jsonSchema; + /** * Constructor to initialize with the target type's class. * @param clazz The target type's class. @@ -110,21 +107,6 @@ public class BeanOutputConverter implements StructuredOutputConverter { this(new CustomizedTypeReference<>(typeRef), objectMapper); } - private static class CustomizedTypeReference extends TypeReference { - - private final Type type; - - CustomizedTypeReference(ParameterizedTypeReference typeRef) { - this.type = typeRef.getType(); - } - - @Override - public Type getType() { - return this.type; - } - - } - /** * Constructor to initialize with the target class type reference, a custom object * mapper, and a line endings normalizer to ensure consistent line endings on any @@ -144,7 +126,9 @@ public class BeanOutputConverter implements StructuredOutputConverter { */ private void generateSchema() { JacksonModule jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED); - SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(DRAFT_2020_12, PLAIN_JSON) + SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder( + com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12, + com.github.victools.jsonschema.generator.OptionPreset.PLAIN_JSON) .with(jacksonModule) .with(Option.FORBIDDEN_ADDITIONAL_PROPERTIES_BY_DEFAULT); SchemaGeneratorConfig config = configBuilder.build(); @@ -234,4 +218,19 @@ public class BeanOutputConverter implements StructuredOutputConverter { return this.jsonSchema; } + private static class CustomizedTypeReference extends TypeReference { + + private final Type type; + + CustomizedTypeReference(ParameterizedTypeReference typeRef) { + this.type = typeRef.getType(); + } + + @Override + public Type getType() { + return this.type; + } + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/FormatProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/FormatProvider.java index 9afbc14e3..eea1e89ce 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/FormatProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/FormatProvider.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/ListOutputConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/ListOutputConverter.java index 65a214f02..3a16e2758 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/ListOutputConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/ListOutputConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/MapOutputConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/MapOutputConverter.java index 682e6fcf7..f5100ebfb 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/MapOutputConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/MapOutputConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; import java.nio.charset.StandardCharsets; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/README.md b/spring-ai-core/src/main/java/org/springframework/ai/converter/README.md index 125f9b4f2..3f9c03d0b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/README.md +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/README.md @@ -8,6 +8,8 @@ It may be a correct JSON, but it isn’t a JSON data structure. It is just a string. Also, asking "for JSON" as part of the prompt isn’t 100% accurate. -This intricacy has led to the emergence of a specialized field involving the creation of prompts to yield the intended output, followed by converting the resulting simple string into a usable data structure for application integration. +This intricacy has led to the emergence of a specialized field involving the creation of prompts to yield the intended +output, followed by converting the resulting simple string into a usable data structure for application integration. -Structure output conversion employs meticulously crafted prompts, often necessitating multiple interactions with the model to achieve the desired formatting. +Structure output conversion employs meticulously crafted prompts, often necessitating multiple interactions with the +model to achieve the desired formatting. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/StructuredOutputConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/StructuredOutputConverter.java index 4756468a8..40a5f5483 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/StructuredOutputConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/StructuredOutputConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; import org.springframework.core.convert.converter.Converter; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/ContentFormatter.java b/spring-ai-core/src/main/java/org/springframework/ai/document/ContentFormatter.java index 8b5ec57c3..a0a2e1d2c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/ContentFormatter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/ContentFormatter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java b/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java index 7db065225..570b3afb4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; import java.util.ArrayList; @@ -30,7 +31,7 @@ import org.springframework.util.Assert; /** * @author Christian Tzolov */ -public class DefaultContentFormatter implements ContentFormatter { +public final class DefaultContentFormatter implements ContentFormatter { private static final String TEMPLATE_CONTENT_PLACEHOLDER = "{content}"; @@ -74,6 +75,14 @@ public class DefaultContentFormatter implements ContentFormatter { */ private final List excludedEmbedMetadataKeys; + private DefaultContentFormatter(Builder builder) { + this.metadataTemplate = builder.metadataTemplate; + this.metadataSeparator = builder.metadataSeparator; + this.textTemplate = builder.textTemplate; + this.excludedInferenceMetadataKeys = builder.excludedInferenceMetadataKeys; + this.excludedEmbedMetadataKeys = builder.excludedEmbedMetadataKeys; + } + /** * Start building a new configuration. * @return The entry point for creating a new configuration. @@ -90,15 +99,71 @@ public class DefaultContentFormatter implements ContentFormatter { return builder().build(); } - private DefaultContentFormatter(Builder builder) { - this.metadataTemplate = builder.metadataTemplate; - this.metadataSeparator = builder.metadataSeparator; - this.textTemplate = builder.textTemplate; - this.excludedInferenceMetadataKeys = builder.excludedInferenceMetadataKeys; - this.excludedEmbedMetadataKeys = builder.excludedEmbedMetadataKeys; + @Override + public String format(Document document, MetadataMode metadataMode) { + + var metadata = metadataFilter(document.getMetadata(), metadataMode); + + var metadataText = metadata.entrySet() + .stream() + .map(metadataEntry -> this.metadataTemplate.replace(TEMPLATE_KEY_PLACEHOLDER, metadataEntry.getKey()) + .replace(TEMPLATE_VALUE_PLACEHOLDER, metadataEntry.getValue().toString())) + .collect(Collectors.joining(this.metadataSeparator)); + + return this.textTemplate.replace(TEMPLATE_METADATA_STRING_PLACEHOLDER, metadataText) + .replace(TEMPLATE_CONTENT_PLACEHOLDER, document.getContent()); } - public static class Builder { + /** + * Filters the metadata by the configured MetadataMode. + * @param metadata Document metadata. + * @return Returns the filtered by configured mode metadata. + */ + protected Map metadataFilter(Map metadata, MetadataMode metadataMode) { + + if (metadataMode == MetadataMode.ALL) { + return new HashMap(metadata); + } + if (metadataMode == MetadataMode.NONE) { + return new HashMap(Collections.emptyMap()); + } + + Set usableMetadataKeys = new HashSet<>(metadata.keySet()); + + if (metadataMode == MetadataMode.INFERENCE) { + usableMetadataKeys.removeAll(this.excludedInferenceMetadataKeys); + } + else if (metadataMode == MetadataMode.EMBED) { + usableMetadataKeys.removeAll(this.excludedEmbedMetadataKeys); + } + + return new HashMap(metadata.entrySet() + .stream() + .filter(e -> usableMetadataKeys.contains(e.getKey())) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()))); + } + + public String getMetadataTemplate() { + return this.metadataTemplate; + } + + public String getMetadataSeparator() { + return this.metadataSeparator; + } + + public String getTextTemplate() { + return this.textTemplate; + } + + public List getExcludedInferenceMetadataKeys() { + return Collections.unmodifiableList(this.excludedInferenceMetadataKeys); + } + + public List getExcludedEmbedMetadataKeys() { + return Collections.unmodifiableList(this.excludedEmbedMetadataKeys); + } + + public static final class Builder { private String metadataTemplate = DEFAULT_METADATA_TEMPLATE; @@ -199,68 +264,4 @@ public class DefaultContentFormatter implements ContentFormatter { } - @Override - public String format(Document document, MetadataMode metadataMode) { - - var metadata = metadataFilter(document.getMetadata(), metadataMode); - - var metadataText = metadata.entrySet() - .stream() - .map(metadataEntry -> this.metadataTemplate.replace(TEMPLATE_KEY_PLACEHOLDER, metadataEntry.getKey()) - .replace(TEMPLATE_VALUE_PLACEHOLDER, metadataEntry.getValue().toString())) - .collect(Collectors.joining(this.metadataSeparator)); - - return this.textTemplate.replace(TEMPLATE_METADATA_STRING_PLACEHOLDER, metadataText) - .replace(TEMPLATE_CONTENT_PLACEHOLDER, document.getContent()); - } - - /** - * Filters the metadata by the configured MetadataMode. - * @param metadata Document metadata. - * @return Returns the filtered by configured mode metadata. - */ - protected Map metadataFilter(Map metadata, MetadataMode metadataMode) { - - if (metadataMode == MetadataMode.ALL) { - return new HashMap(metadata); - } - if (metadataMode == MetadataMode.NONE) { - return new HashMap(Collections.emptyMap()); - } - - Set usableMetadataKeys = new HashSet<>(metadata.keySet()); - - if (metadataMode == MetadataMode.INFERENCE) { - usableMetadataKeys.removeAll(this.excludedInferenceMetadataKeys); - } - else if (metadataMode == MetadataMode.EMBED) { - usableMetadataKeys.removeAll(this.excludedEmbedMetadataKeys); - } - - return new HashMap(metadata.entrySet() - .stream() - .filter(e -> usableMetadataKeys.contains(e.getKey())) - .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()))); - } - - public String getMetadataTemplate() { - return this.metadataTemplate; - } - - public String getMetadataSeparator() { - return this.metadataSeparator; - } - - public String getTextTemplate() { - return this.textTemplate; - } - - public List getExcludedInferenceMetadataKeys() { - return Collections.unmodifiableList(this.excludedInferenceMetadataKeys); - } - - public List getExcludedEmbedMetadataKeys() { - return Collections.unmodifiableList(this.excludedEmbedMetadataKeys); - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java b/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java index 7722a4dfe..e666857dd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.document; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.document.id.IdGenerator; -import org.springframework.ai.document.id.RandomIdGenerator; -import org.springframework.ai.model.Media; -import org.springframework.ai.model.MediaContent; -import org.springframework.util.Assert; -import org.springframework.util.StringUtils; +package org.springframework.ai.document; import java.util.ArrayList; import java.util.Collection; @@ -32,6 +22,18 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.document.id.IdGenerator; +import org.springframework.ai.document.id.RandomIdGenerator; +import org.springframework.ai.model.Media; +import org.springframework.ai.model.MediaContent; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + /** * A document is a container for the content and metadata of a document. It also contains * the document's unique ID and an optional embedding. @@ -48,12 +50,6 @@ public class Document implements MediaContent { */ private final String id; - /** - * Metadata for the document. It should not be nested and values should be restricted - * to string, int, float, boolean for simple use with Vector Dbs. - */ - private Map metadata; - /** * Document content. */ @@ -61,6 +57,12 @@ public class Document implements MediaContent { private final Collection media; + /** + * Metadata for the document. It should not be nested and values should be restricted + * to string, int, float, boolean for simple use with Vector Dbs. + */ + private Map metadata; + /** * Embedding of the document. Note: ephemeral field. */ @@ -109,6 +111,120 @@ public class Document implements MediaContent { return new Builder(); } + public String getId() { + return this.id; + } + + @Override + public String getContent() { + return this.content; + } + + @Override + public Collection getMedia() { + return this.media; + } + + @JsonIgnore + public String getFormattedContent() { + return this.getFormattedContent(MetadataMode.ALL); + } + + public String getFormattedContent(MetadataMode metadataMode) { + Assert.notNull(metadataMode, "Metadata mode must not be null"); + return this.contentFormatter.format(this, metadataMode); + } + + /** + * Helper content extractor that uses and external {@link ContentFormatter}. + */ + public String getFormattedContent(ContentFormatter formatter, MetadataMode metadataMode) { + Assert.notNull(formatter, "formatter must not be null"); + Assert.notNull(metadataMode, "Metadata mode must not be null"); + return formatter.format(this, metadataMode); + } + + @Override + public Map getMetadata() { + return this.metadata; + } + + public float[] getEmbedding() { + return this.embedding; + } + + public void setEmbedding(float[] embedding) { + Assert.notNull(embedding, "embedding must not be null"); + this.embedding = embedding; + } + + public ContentFormatter getContentFormatter() { + return this.contentFormatter; + } + + /** + * Replace the document's {@link ContentFormatter}. + * @param contentFormatter new formatter to use. + */ + public void setContentFormatter(ContentFormatter contentFormatter) { + this.contentFormatter = contentFormatter; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((this.id == null) ? 0 : this.id.hashCode()); + result = prime * result + ((this.metadata == null) ? 0 : this.metadata.hashCode()); + result = prime * result + ((this.content == null) ? 0 : this.content.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + Document other = (Document) obj; + if (this.id == null) { + if (other.id != null) { + return false; + } + } + else if (!this.id.equals(other.id)) { + return false; + } + if (this.metadata == null) { + if (other.metadata != null) { + return false; + } + } + else if (!this.metadata.equals(other.metadata)) { + return false; + } + if (this.content == null) { + if (other.content != null) { + return false; + } + } + else if (!this.content.equals(other.content)) { + return false; + } + return true; + } + + @Override + public String toString() { + return "Document{" + "id='" + this.id + '\'' + ", metadata=" + this.metadata + ", content='" + this.content + + '\'' + ", media=" + this.media + '}'; + } + public static class Builder { private String id; @@ -166,116 +282,11 @@ public class Document implements MediaContent { public Document build() { if (!StringUtils.hasText(this.id)) { - this.id = this.idGenerator.generateId(content, metadata); + this.id = this.idGenerator.generateId(this.content, this.metadata); } - return new Document(id, content, media, metadata); + return new Document(this.id, this.content, this.media, this.metadata); } } - public String getId() { - return id; - } - - @Override - public String getContent() { - return this.content; - } - - @Override - public Collection getMedia() { - return this.media; - } - - @JsonIgnore - public String getFormattedContent() { - return this.getFormattedContent(MetadataMode.ALL); - } - - public String getFormattedContent(MetadataMode metadataMode) { - Assert.notNull(metadataMode, "Metadata mode must not be null"); - return this.contentFormatter.format(this, metadataMode); - } - - /** - * Helper content extractor that uses and external {@link ContentFormatter}. - */ - public String getFormattedContent(ContentFormatter formatter, MetadataMode metadataMode) { - Assert.notNull(formatter, "formatter must not be null"); - Assert.notNull(metadataMode, "Metadata mode must not be null"); - return formatter.format(this, metadataMode); - } - - public void setEmbedding(float[] embedding) { - Assert.notNull(embedding, "embedding must not be null"); - this.embedding = embedding; - } - - /** - * Replace the document's {@link ContentFormatter}. - * @param contentFormatter new formatter to use. - */ - public void setContentFormatter(ContentFormatter contentFormatter) { - this.contentFormatter = contentFormatter; - } - - @Override - public Map getMetadata() { - return this.metadata; - } - - public float[] getEmbedding() { - return this.embedding; - } - - public ContentFormatter getContentFormatter() { - return contentFormatter; - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((id == null) ? 0 : id.hashCode()); - result = prime * result + ((metadata == null) ? 0 : metadata.hashCode()); - result = prime * result + ((content == null) ? 0 : content.hashCode()); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - Document other = (Document) obj; - if (id == null) { - if (other.id != null) - return false; - } - else if (!id.equals(other.id)) - return false; - if (metadata == null) { - if (other.metadata != null) - return false; - } - else if (!metadata.equals(other.metadata)) - return false; - if (content == null) { - if (other.content != null) - return false; - } - else if (!content.equals(other.content)) - return false; - return true; - } - - @Override - public String toString() { - return "Document{" + "id='" + id + '\'' + ", metadata=" + metadata + ", content='" + content + '\'' + ", media=" - + media + '}'; - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentReader.java b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentReader.java index 75b4fe2b2..f6179ca4e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentReader.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentReader.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentRetriever.java b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentRetriever.java index 50d0b4b13..618af3fc6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentRetriever.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentRetriever.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentTransformer.java b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentTransformer.java index 8c325a7bd..6f17faf7c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentTransformer.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentTransformer.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentWriter.java b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentWriter.java index 31aeaf905..a85fb49e5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentWriter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentWriter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/MetadataMode.java b/spring-ai-core/src/main/java/org/springframework/ai/document/MetadataMode.java index 3d32a2b5d..733e1cbfb 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/MetadataMode.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/MetadataMode.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; public enum MetadataMode { - ALL, EMBED, INFERENCE, NONE; + ALL, EMBED, INFERENCE, NONE -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/id/IdGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/document/id/IdGenerator.java index 198c114d5..f9c43726b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/id/IdGenerator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/id/IdGenerator.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document.id; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/id/JdkSha256HexIdGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/document/id/JdkSha256HexIdGenerator.java index 1302a6dd3..ca561b035 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/id/JdkSha256HexIdGenerator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/id/JdkSha256HexIdGenerator.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document.id; import java.io.ByteArrayOutputStream; @@ -86,7 +87,7 @@ public class JdkSha256HexIdGenerator implements IdGenerator { MessageDigest getMessageDigest() { try { - return (MessageDigest) messageDigest.clone(); + return (MessageDigest) this.messageDigest.clone(); } catch (CloneNotSupportedException e) { throw new RuntimeException("Unsupported clone for MessageDigest.", e); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/id/RandomIdGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/document/id/RandomIdGenerator.java index 0920e9a04..8c50c2e56 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/id/RandomIdGenerator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/id/RandomIdGenerator.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document.id; import java.util.UUID; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java index 1c6c1c374..8165e5d7a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.io.IOException; @@ -31,10 +32,10 @@ import org.springframework.core.io.DefaultResourceLoader; */ public abstract class AbstractEmbeddingModel implements EmbeddingModel { - protected final AtomicInteger embeddingDimensions = new AtomicInteger(-1); - private static Map KNOWN_EMBEDDING_DIMENSIONS = loadKnownModelDimensions(); + protected final AtomicInteger embeddingDimensions = new AtomicInteger(-1); + /** * Return the dimension of the requested embedding generative name. If the generative * name is unknown uses the EmbeddingModel to perform a dummy EmbeddingModel#embed and diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/BatchingStrategy.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/BatchingStrategy.java index 4f73cab06..e354f1da8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/BatchingStrategy.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/BatchingStrategy.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingModel.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingModel.java index eb4a83540..6237b98d9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import org.springframework.ai.model.Model; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingRequest.java index 8227b2910..6fc754c11 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingRequest.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.util.Arrays; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java index daaa20d0e..1dabb36d4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.util.Objects; @@ -56,7 +57,7 @@ public class Embedding implements ModelResult { */ @Override public float[] getOutput() { - return embedding; + return this.embedding; } /** @@ -75,17 +76,19 @@ public class Embedding implements ModelResult { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || getClass() != o.getClass()) { return false; + } Embedding other = (Embedding) o; return Objects.equals(this.embedding, other.embedding) && Objects.equals(this.index, other.index); } @Override public int hashCode() { - return Objects.hash(embedding, index); + return Objects.hash(this.embedding, this.index); } @Override diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java index 874fadfed..e4785b867 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; +import java.util.ArrayList; +import java.util.List; + import org.springframework.ai.document.Document; import org.springframework.ai.model.Model; import org.springframework.util.Assert; -import java.util.ArrayList; -import java.util.List; - /** * EmbeddingModel is a generic interface for embedding models. * diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptions.java index 3fac81190..f7461249f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptions.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import org.springframework.ai.model.ModelOptions; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptionsBuilder.java index ab13dff1a..cdb4fb999 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptionsBuilder.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; /** * @author Thomas Vitale * @since 1.0.0 */ -public class EmbeddingOptionsBuilder { +public final class EmbeddingOptionsBuilder { private final DefaultEmbeddingOptions embeddingOptions = new DefaultEmbeddingOptions(); @@ -31,17 +32,17 @@ public class EmbeddingOptionsBuilder { } public EmbeddingOptionsBuilder withModel(String model) { - embeddingOptions.setModel(model); + this.embeddingOptions.setModel(model); return this; } public EmbeddingOptionsBuilder withDimensions(Integer dimensions) { - embeddingOptions.setDimensions(dimensions); + this.embeddingOptions.setDimensions(dimensions); return this; } public EmbeddingOptions build() { - return embeddingOptions; + return this.embeddingOptions; } private static class DefaultEmbeddingOptions implements EmbeddingOptions { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingRequest.java index e5512bfe2..70429783e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingRequest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java index b89262567..2ad2afac3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.util.List; @@ -58,13 +59,13 @@ public class EmbeddingResponse implements ModelResponse { * @return Get the embedding metadata. */ public EmbeddingResponseMetadata getMetadata() { - return metadata; + return this.metadata; } @Override public Embedding getResult() { - Assert.notEmpty(embeddings, "No embedding data available."); - return embeddings.get(0); + Assert.notEmpty(this.embeddings, "No embedding data available."); + return this.embeddings.get(0); } /** @@ -72,27 +73,29 @@ public class EmbeddingResponse implements ModelResponse { */ @Override public List getResults() { - return embeddings; + return this.embeddings; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || getClass() != o.getClass()) { return false; + } EmbeddingResponse that = (EmbeddingResponse) o; - return Objects.equals(embeddings, that.embeddings) && Objects.equals(metadata, that.metadata); + return Objects.equals(this.embeddings, that.embeddings) && Objects.equals(this.metadata, that.metadata); } @Override public int hashCode() { - return Objects.hash(embeddings, metadata); + return Objects.hash(this.embeddings, this.metadata); } @Override public String toString() { - return "EmbeddingResult{" + "data=" + embeddings + ", metadata=" + metadata + '}'; + return "EmbeddingResult{" + "data=" + this.embeddings + ", metadata=" + this.metadata + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponseMetadata.java index 335ac0ae2..a9440dc7a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; +import java.util.Map; + import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.model.AbstractResponseMetadata; import org.springframework.ai.model.ResponseMetadata; -import java.util.Map; - /** * Common AI provider metadata returned in an embedding response. * diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResultMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResultMetadata.java index 9b7df810b..eb0dfead4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResultMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResultMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import org.springframework.ai.model.ResultMetadata; @@ -27,12 +28,6 @@ public class EmbeddingResultMetadata implements ResultMetadata { public static EmbeddingResultMetadata EMPTY = new EmbeddingResultMetadata(); - public enum ModalityType { - - TEXT, IMAGE, AUDIO, VIDEO; - - } - /** * The {@link MimeType} of the source data used to generate the embedding. */ @@ -75,6 +70,12 @@ public class EmbeddingResultMetadata implements ResultMetadata { return this.documentData; } + public enum ModalityType { + + TEXT, IMAGE, AUDIO, VIDEO + + } + public static class ModalityUtils { private static MimeType TEXT_MIME_TYPE = MimeTypeUtils.parseMimeType("text/*"); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java index 2ff2dce0e..298278c35 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.util.ArrayList; @@ -20,13 +21,13 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import com.knuddels.jtokkit.api.EncodingType; + import org.springframework.ai.document.ContentFormatter; import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator; import org.springframework.ai.tokenizer.TokenCountEstimator; - -import com.knuddels.jtokkit.api.EncodingType; import org.springframework.util.Assert; /** @@ -144,7 +145,7 @@ public class TokenCountBatchingStrategy implements BatchingStrategy { for (Document document : documentTokens.keySet()) { Integer tokenCount = documentTokens.get(document); - if (currentSize + tokenCount > maxInputTokenCount) { + if (currentSize + tokenCount > this.maxInputTokenCount) { batches.add(currentBatch); currentBatch = new ArrayList<>(); currentSize = 0; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java index 6e4269fe9..6949f0e00 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding.observation; import io.micrometer.common.KeyValue; import io.micrometer.common.KeyValues; + import org.springframework.util.StringUtils; /** @@ -27,14 +29,14 @@ import org.springframework.util.StringUtils; */ public class DefaultEmbeddingModelObservationConvention implements EmbeddingModelObservationConvention { + public static final String DEFAULT_NAME = "gen_ai.client.operation"; + private static final KeyValue REQUEST_MODEL_NONE = KeyValue .of(EmbeddingModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL, KeyValue.NONE_VALUE); private static final KeyValue RESPONSE_MODEL_NONE = KeyValue .of(EmbeddingModelObservationDocumentation.LowCardinalityKeyNames.RESPONSE_MODEL, KeyValue.NONE_VALUE); - public static final String DEFAULT_NAME = "gen_ai.client.operation"; - @Override public String getName() { return DEFAULT_NAME; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandler.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandler.java index 8d5fb754b..84a7b8f04 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandler.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandler.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding.observation; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationHandler; + import org.springframework.ai.model.observation.ModelUsageMetricsGenerator; /** @@ -38,7 +40,8 @@ public class EmbeddingModelMeterObservationHandler implements ObservationHandler public void onStop(EmbeddingModelObservationContext context) { if (context.getResponse() != null && context.getResponse().getMetadata() != null && context.getResponse().getMetadata().getUsage() != null) { - ModelUsageMetricsGenerator.generate(context.getResponse().getMetadata().getUsage(), context, meterRegistry); + ModelUsageMetricsGenerator.generate(context.getResponse().getMetadata().getUsage(), context, + this.meterRegistry); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java index 2b6b09c67..9b46135ae 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding.observation; import org.springframework.ai.embedding.EmbeddingOptions; @@ -44,15 +45,15 @@ public class EmbeddingModelObservationContext extends ModelObservationContext getDataList() { - return dataList; + return this.dataList; } public String getResponseContent() { - return responseContent; + return this.responseContent; } @Override public String toString() { - return "EvaluationRequest{" + "userText='" + userText + '\'' + ", dataList=" + dataList + ", chatResponse=" - + responseContent + '}'; + return "EvaluationRequest{" + "userText='" + this.userText + '\'' + ", dataList=" + this.dataList + + ", chatResponse=" + this.responseContent + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof EvaluationRequest that)) + } + if (!(o instanceof EvaluationRequest that)) { return false; - return Objects.equals(userText, that.userText) && Objects.equals(dataList, that.dataList) - && Objects.equals(responseContent, that.responseContent); + } + return Objects.equals(this.userText, that.userText) && Objects.equals(this.dataList, that.dataList) + && Objects.equals(this.responseContent, that.responseContent); } @Override public int hashCode() { - return Objects.hash(userText, dataList, responseContent); + return Objects.hash(this.userText, this.dataList, this.responseContent); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/EvaluationResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/EvaluationResponse.java index f866cb5e2..ead7fa565 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/EvaluationResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/EvaluationResponse.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.evaluation; import java.util.Map; @@ -29,40 +45,42 @@ public class EvaluationResponse { } public boolean isPass() { - return pass; + return this.pass; } public float getScore() { - return score; + return this.score; } public String getFeedback() { - return feedback; + return this.feedback; } public Map getMetadata() { - return metadata; + return this.metadata; } @Override public String toString() { - return "EvaluationResponse{" + "pass=" + pass + ", score=" + score + ", feedback='" + feedback + '\'' - + ", metadata=" + metadata + '}'; + return "EvaluationResponse{" + "pass=" + this.pass + ", score=" + this.score + ", feedback='" + this.feedback + + '\'' + ", metadata=" + this.metadata + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof EvaluationResponse that)) + } + if (!(o instanceof EvaluationResponse that)) { return false; - return pass == that.pass && Float.compare(score, that.score) == 0 && Objects.equals(feedback, that.feedback) - && Objects.equals(metadata, that.metadata); + } + return this.pass == that.pass && Float.compare(this.score, that.score) == 0 + && Objects.equals(this.feedback, that.feedback) && Objects.equals(this.metadata, that.metadata); } @Override public int hashCode() { - return Objects.hash(pass, score, feedback, metadata); + return Objects.hash(this.pass, this.score, this.feedback, this.metadata); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/Evaluator.java b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/Evaluator.java index b14fe2adb..9b12205d3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/Evaluator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/Evaluator.java @@ -1,11 +1,27 @@ -package org.springframework.ai.evaluation; +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import org.springframework.ai.model.Content; -import org.springframework.util.StringUtils; +package org.springframework.ai.evaluation; import java.util.List; import java.util.stream.Collectors; +import org.springframework.ai.model.Content; +import org.springframework.util.StringUtils; + @FunctionalInterface public interface Evaluator { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/FactCheckingEvaluator.java b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/FactCheckingEvaluator.java index eb16f66c4..77bd676ed 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/FactCheckingEvaluator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/FactCheckingEvaluator.java @@ -1,9 +1,25 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.evaluation; -import org.springframework.ai.chat.client.ChatClient; - import java.util.Collections; +import org.springframework.ai.chat.client.ChatClient; + /** * The FactCheckingEvaluator class implements a method for evaluating the factual accuracy * of Large Language Model (LLM) responses against provided context. @@ -48,8 +64,8 @@ import java.util.Collections; public class FactCheckingEvaluator implements Evaluator { private static final String DEFAULT_EVALUATION_PROMPT_TEXT = """ - Document: \\n {document}\\n - Claim: \\n {claim} + Document: \\n {document}\\n + Claim: \\n {claim} """; private final ChatClient.Builder chatClientBuilder; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java index 5a0ec203a..b85591333 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java @@ -1,21 +1,37 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.evaluation; -import org.springframework.ai.chat.client.ChatClient; - import java.util.Collections; +import org.springframework.ai.chat.client.ChatClient; + public class RelevancyEvaluator implements Evaluator { private static final String DEFAULT_EVALUATION_PROMPT_TEXT = """ - Your task is to evaluate if the response for the query - is in line with the context information provided.\\n - You have two options to answer. Either YES/ NO.\\n - Answer - YES, if the response for the query - is in line with context information otherwise NO.\\n - Query: \\n {query}\\n - Response: \\n {response}\\n - Context: \\n {context}\\n - Answer: " + Your task is to evaluate if the response for the query + is in line with the context information provided.\\n + You have two options to answer. Either YES/ NO.\\n + Answer - YES, if the response for the query + is in line with context information otherwise NO.\\n + Query: \\n {query}\\n + Response: \\n {response}\\n + Context: \\n {context}\\n + Answer: " """; private final ChatClient.Builder chatClientBuilder; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java b/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java index 8adc3677e..bf1f683a1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import java.util.Objects; @@ -35,7 +36,7 @@ public class Image { } public String getUrl() { - return url; + return this.url; } public void setUrl(String url) { @@ -43,7 +44,7 @@ public class Image { } public String getB64Json() { - return b64Json; + return this.b64Json; } public void setB64Json(String b64Json) { @@ -52,21 +53,23 @@ public class Image { @Override public String toString() { - return "Image{" + "url='" + url + '\'' + ", b64Json='" + b64Json + '\'' + '}'; + return "Image{" + "url='" + this.url + '\'' + ", b64Json='" + this.b64Json + '\'' + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof Image image)) + } + if (!(o instanceof Image image)) { return false; - return Objects.equals(url, image.url) && Objects.equals(b64Json, image.b64Json); + } + return Objects.equals(this.url, image.url) && Objects.equals(this.b64Json, image.b64Json); } @Override public int hashCode() { - return Objects.hash(url, b64Json); + return Objects.hash(this.url, this.b64Json); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java index 431afd813..3f9425f55 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import org.springframework.ai.model.ModelResult; @@ -34,17 +35,18 @@ public class ImageGeneration implements ModelResult { @Override public Image getOutput() { - return image; + return this.image; } @Override public ImageGenerationMetadata getMetadata() { - return imageGenerationMetadata; + return this.imageGenerationMetadata; } @Override public String toString() { - return "ImageGeneration{" + "imageGenerationMetadata=" + imageGenerationMetadata + ", image=" + image + '}'; + return "ImageGeneration{" + "imageGenerationMetadata=" + this.imageGenerationMetadata + ", image=" + this.image + + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java index 164f781d1..7a513390c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import org.springframework.ai.model.ResultMetadata; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java index 2b298bb07..72825b1a4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import java.util.Objects; @@ -33,30 +34,32 @@ public class ImageMessage { } public String getText() { - return text; + return this.text; } public Float getWeight() { - return weight; + return this.weight; } @Override public String toString() { - return "ImageMessage{" + "text='" + text + '\'' + ", weight=" + weight + '}'; + return "ImageMessage{" + "text='" + this.text + '\'' + ", weight=" + this.weight + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ImageMessage that)) + } + if (!(o instanceof ImageMessage that)) { return false; - return Objects.equals(text, that.text) && Objects.equals(weight, that.weight); + } + return Objects.equals(this.text, that.text) && Objects.equals(this.weight, that.weight); } @Override public int hashCode() { - return Objects.hash(text, weight); + return Objects.hash(this.text, this.weight); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageModel.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageModel.java index 493da50bf..466931a68 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import org.springframework.ai.model.Model; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java index af8962029..435f6fc62 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import org.springframework.ai.model.ModelOptions; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java index 7917f3df5..30f1f0105 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,54 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; -public class ImageOptionsBuilder { +public final class ImageOptionsBuilder { + + private final DefaultImageModelOptions options = new DefaultImageModelOptions(); + + private ImageOptionsBuilder() { + + } + + public static ImageOptionsBuilder builder() { + return new ImageOptionsBuilder(); + } + + public ImageOptionsBuilder withN(Integer n) { + this.options.setN(n); + return this; + } + + public ImageOptionsBuilder withModel(String model) { + this.options.setModel(model); + return this; + } + + public ImageOptionsBuilder withResponseFormat(String responseFormat) { + this.options.setResponseFormat(responseFormat); + return this; + } + + public ImageOptionsBuilder withWidth(Integer width) { + this.options.setWidth(width); + return this; + } + + public ImageOptionsBuilder withHeight(Integer height) { + this.options.setHeight(height); + return this; + } + + public ImageOptionsBuilder withStyle(String style) { + this.options.setStyle(style); + return this; + } + + public ImageOptions build() { + return this.options; + } private static class DefaultImageModelOptions implements ImageOptions { @@ -33,7 +78,7 @@ public class ImageOptionsBuilder { @Override public Integer getN() { - return n; + return this.n; } public void setN(Integer n) { @@ -42,7 +87,7 @@ public class ImageOptionsBuilder { @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -51,7 +96,7 @@ public class ImageOptionsBuilder { @Override public String getResponseFormat() { - return responseFormat; + return this.responseFormat; } public void setResponseFormat(String responseFormat) { @@ -60,7 +105,7 @@ public class ImageOptionsBuilder { @Override public Integer getWidth() { - return width; + return this.width; } public void setWidth(Integer width) { @@ -69,7 +114,7 @@ public class ImageOptionsBuilder { @Override public Integer getHeight() { - return height; + return this.height; } public void setHeight(Integer height) { @@ -78,7 +123,7 @@ public class ImageOptionsBuilder { @Override public String getStyle() { - return style; + return this.style; } public void setStyle(String style) { @@ -87,48 +132,4 @@ public class ImageOptionsBuilder { } - private final DefaultImageModelOptions options = new DefaultImageModelOptions(); - - private ImageOptionsBuilder() { - - } - - public static ImageOptionsBuilder builder() { - return new ImageOptionsBuilder(); - } - - public ImageOptionsBuilder withN(Integer n) { - options.setN(n); - return this; - } - - public ImageOptionsBuilder withModel(String model) { - options.setModel(model); - return this; - } - - public ImageOptionsBuilder withResponseFormat(String responseFormat) { - options.setResponseFormat(responseFormat); - return this; - } - - public ImageOptionsBuilder withWidth(Integer width) { - options.setWidth(width); - return this; - } - - public ImageOptionsBuilder withHeight(Integer height) { - options.setHeight(height); - return this; - } - - public ImageOptionsBuilder withStyle(String style) { - options.setStyle(style); - return this; - } - - public ImageOptions build() { - return options; - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java index 59ac64c81..a212c2cf4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.image; -import org.springframework.ai.model.ModelRequest; +package org.springframework.ai.image; import java.util.Collections; import java.util.List; import java.util.Objects; +import org.springframework.ai.model.ModelRequest; + public class ImagePrompt implements ModelRequest> { private final List messages; @@ -50,31 +51,34 @@ public class ImagePrompt implements ModelRequest> { @Override public List getInstructions() { - return messages; + return this.messages; } @Override public ImageOptions getOptions() { - return imageModelOptions; + return this.imageModelOptions; } @Override public String toString() { - return "NewImagePrompt{" + "messages=" + messages + ", imageModelOptions=" + imageModelOptions + '}'; + return "NewImagePrompt{" + "messages=" + this.messages + ", imageModelOptions=" + this.imageModelOptions + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ImagePrompt that)) + } + if (!(o instanceof ImagePrompt that)) { return false; - return Objects.equals(messages, that.messages) && Objects.equals(imageModelOptions, that.imageModelOptions); + } + return Objects.equals(this.messages, that.messages) + && Objects.equals(this.imageModelOptions, that.imageModelOptions); } @Override public int hashCode() { - return Objects.hash(messages, imageModelOptions); + return Objects.hash(this.messages, this.imageModelOptions); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java index b6d6c87b8..c4605d818 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import java.util.List; @@ -67,7 +68,7 @@ public class ImageResponse implements ModelResponse { */ @Override public List getResults() { - return imageGenerations; + return this.imageGenerations; } /** @@ -78,7 +79,7 @@ public class ImageResponse implements ModelResponse { if (CollectionUtils.isEmpty(this.imageGenerations)) { return null; } - return imageGenerations.get(0); + return this.imageGenerations.get(0); } /** @@ -87,28 +88,30 @@ public class ImageResponse implements ModelResponse { */ @Override public ImageResponseMetadata getMetadata() { - return imageResponseMetadata; + return this.imageResponseMetadata; } @Override public String toString() { - return "ImageResponse [" + "imageResponseMetadata=" + imageResponseMetadata + ", imageGenerations=" - + imageGenerations + "]"; + return "ImageResponse [" + "imageResponseMetadata=" + this.imageResponseMetadata + ", imageGenerations=" + + this.imageGenerations + "]"; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ImageResponse that)) + } + if (!(o instanceof ImageResponse that)) { return false; - return Objects.equals(imageResponseMetadata, that.imageResponseMetadata) - && Objects.equals(imageGenerations, that.imageGenerations); + } + return Objects.equals(this.imageResponseMetadata, that.imageResponseMetadata) + && Objects.equals(this.imageGenerations, that.imageGenerations); } @Override public int hashCode() { - return Objects.hash(imageResponseMetadata, imageGenerations); + return Objects.hash(this.imageResponseMetadata, this.imageGenerations); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java index 5d694e548..816c92b28 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image; import org.springframework.ai.model.MutableResponseMetadata; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java index 2413e51e4..35cb6f51b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/observation/DefaultImageModelObservationConvention.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image.observation; import io.micrometer.common.KeyValue; import io.micrometer.common.KeyValues; + import org.springframework.util.StringUtils; /** @@ -27,11 +29,11 @@ import org.springframework.util.StringUtils; */ public class DefaultImageModelObservationConvention implements ImageModelObservationConvention { + public static final String DEFAULT_NAME = "gen_ai.client.operation"; + private static final KeyValue REQUEST_MODEL_NONE = KeyValue .of(ImageModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL, KeyValue.NONE_VALUE); - public static final String DEFAULT_NAME = "gen_ai.client.operation"; - @Override public String getName() { return DEFAULT_NAME; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/observation/ImageModelObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/image/observation/ImageModelObservationContext.java index 52846a41d..34ba0f770 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/observation/ImageModelObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/observation/ImageModelObservationContext.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image.observation; import org.springframework.ai.image.ImageOptions; @@ -40,19 +41,19 @@ public class ImageModelObservationContext extends ModelObservationContext doubleToFloat(final List doubles) { return doubles.stream().map(f -> f.floatValue()).toList(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/Media.java b/spring-ai-core/src/main/java/org/springframework/ai/model/Media.java index 5391f1f98..fe5cd8212 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/Media.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/Media.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; +import java.io.IOException; +import java.net.URL; + import org.springframework.core.io.Resource; import org.springframework.util.Assert; import org.springframework.util.MimeType; -import java.io.IOException; -import java.net.URL; - /** * The Media class represents the data and metadata of a media attachment in a message. It * consists of a MIME type and the raw data. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/MediaContent.java b/spring-ai-core/src/main/java/org/springframework/ai/model/MediaContent.java index 4b436e82a..933ded36b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/MediaContent.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/MediaContent.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model; import java.util.Collection; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/Model.java b/spring-ai-core/src/main/java/org/springframework/ai/model/Model.java index 1671a3dc4..391786543 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/Model.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/Model.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelDescription.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelDescription.java index 0335f341e..71382538b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelDescription.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelDescription.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java index 124818e80..10f54e02c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index 8a6cc7d0d..e049fb17f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; import java.beans.PropertyDescriptor; @@ -26,13 +27,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; -import org.springframework.ai.util.JacksonUtils; -import org.springframework.beans.BeanWrapper; -import org.springframework.beans.BeanWrapperImpl; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.ObjectUtils; - import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; @@ -53,6 +47,13 @@ import com.github.victools.jsonschema.module.jackson.JacksonModule; import com.github.victools.jsonschema.module.jackson.JacksonOption; import com.github.victools.jsonschema.module.swagger2.Swagger2Module; +import org.springframework.ai.util.JacksonUtils; +import org.springframework.beans.BeanWrapper; +import org.springframework.beans.BeanWrapperImpl; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.ObjectUtils; + /** * Utility class for manipulating {@link ModelOptions} objects. * @@ -74,6 +75,10 @@ public abstract class ModelOptionsUtils { private static final AtomicReference SCHEMA_GENERATOR_CACHE = new AtomicReference<>(); + private static TypeReference> MAP_TYPE_REF = new TypeReference>() { + + }; + /** * Converts the given JSON string to a Map of String and Object. * @param json the JSON string to convert to a Map. @@ -88,9 +93,6 @@ public abstract class ModelOptionsUtils { } } - private static TypeReference> MAP_TYPE_REF = new TypeReference>() { - }; - /** * Converts the given JSON string to an Object of the given type. * @param the type of the object to return. @@ -193,6 +195,7 @@ public abstract class ModelOptionsUtils { try { String json = OBJECT_MAPPER.writeValueAsString(source); return OBJECT_MAPPER.readValue(json, new TypeReference>() { + }) .entrySet() .stream() @@ -356,7 +359,7 @@ public abstract class ModelOptionsUtils { ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(clazz); if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI - // version of it). + // version of it). toUpperCaseTypeValues(node); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java index 94c2e8aef..7b86a8507 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; /** @@ -40,4 +41,4 @@ public interface ModelRequest { */ ModelOptions getOptions(); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java index f4a9bf83a..5df8b8d2a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java index 6ee17815c..f28a9dfc1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/MutableResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/model/MutableResponseMetadata.java index ac0c9254e..106a90e68 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/MutableResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/MutableResponseMetadata.java @@ -1,7 +1,20 @@ -package org.springframework.ai.model; +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import io.micrometer.common.lang.NonNull; -import io.micrometer.common.lang.Nullable; +package org.springframework.ai.model; import java.util.Collections; import java.util.Map; @@ -9,6 +22,9 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; +import io.micrometer.common.lang.NonNull; +import io.micrometer.common.lang.Nullable; + public class MutableResponseMetadata implements ResponseMetadata { private final Map map = new ConcurrentHashMap<>(); @@ -120,7 +136,7 @@ public class MutableResponseMetadata implements ResponseMetadata { } public Map getRawMap() { - return map; + return this.map; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java index 24e544d4f..7b63e91a4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.model; -import io.micrometer.common.lang.NonNull; -import io.micrometer.common.lang.Nullable; +package org.springframework.ai.model; import java.util.Map; import java.util.Set; import java.util.function.Supplier; +import io.micrometer.common.lang.NonNull; +import io.micrometer.common.lang.Nullable; + /** * Interface representing metadata associated with an AI model's response. * @@ -80,7 +81,7 @@ public interface ResponseMetadata { Set> entrySet(); - public Set keySet(); + Set keySet(); /** * Returns {@code true} if this map contains no key-value mappings. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java index 05d2aaca4..85f538d3b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/StreamingModel.java b/spring-ai-core/src/main/java/org/springframework/ai/model/StreamingModel.java index 2c1de77a9..4b11f4fbb 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/StreamingModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/StreamingModel.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; import reactor.core.publisher.Flux; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java index cd5d43be1..8a2c84aca 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.function; import java.util.Objects; import java.util.function.BiFunction; import java.util.function.Function; -import org.springframework.ai.chat.model.ToolContext; -import org.springframework.util.Assert; - import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.util.Assert; + /** * Abstract implementation of the {@link FunctionCallback} for interacting with the * Model's function calling protocol and a {@link Function} wrapping the interaction with @@ -102,7 +103,7 @@ abstract class AbstractFunctionCallback implements BiFunction implements BiFunction implements BiFunction extends AbstractFunctionCallback { +public final class FunctionCallbackWrapper extends AbstractFunctionCallback { private final BiFunction biFunction; @@ -50,11 +51,6 @@ public class FunctionCallbackWrapper extends AbstractFunctionCallback Builder builder(BiFunction biFunction) { return new Builder<>(biFunction); } @@ -63,20 +59,32 @@ public class FunctionCallbackWrapper extends AbstractFunctionCallback(function); } + @Override + public O apply(I input, ToolContext context) { + return this.biFunction.apply(input, context); + } + public static class Builder { + private final BiFunction biFunction; + + private final Function function; + private String name; private String description; private Class inputType; - private final BiFunction biFunction; - - private final Function function; - private SchemaType schemaType = SchemaType.JSON_SCHEMA; + // By default the response is converted to a JSON string. + private Function responseConverter = ModelOptionsUtils::toJsonString; + + private String inputTypeSchema; + + private ObjectMapper objectMapper; + public Builder(BiFunction biFunction) { Assert.notNull(biFunction, "Function must not be null"); this.biFunction = biFunction; @@ -89,12 +97,16 @@ public class FunctionCallbackWrapper extends AbstractFunctionCallback responseConverter = ModelOptionsUtils::toJsonString; + @SuppressWarnings("unchecked") + private static Class resolveInputType(BiFunction biFunction) { + return (Class) TypeResolverHelper + .getBiFunctionInputClass((Class>) biFunction.getClass()); + } - private String inputTypeSchema; - - private ObjectMapper objectMapper; + @SuppressWarnings("unchecked") + private static Class resolveInputType(Function function) { + return (Class) TypeResolverHelper.getFunctionInputClass((Class>) function.getClass()); + } public Builder withName(String name) { Assert.hasText(name, "Name must not be empty"); @@ -173,17 +185,6 @@ public class FunctionCallbackWrapper extends AbstractFunctionCallback Class resolveInputType(BiFunction biFunction) { - return (Class) TypeResolverHelper - .getBiFunctionInputClass((Class>) biFunction.getClass()); - } - - @SuppressWarnings("unchecked") - private static Class resolveInputType(Function function) { - return (Class) TypeResolverHelper.getFunctionInputClass((Class>) function.getClass()); - } - } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java index f61897993..722d7f24f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.function; import java.util.List; @@ -26,6 +27,14 @@ import org.springframework.ai.chat.prompt.ChatOptions; */ public interface FunctionCallingOptions extends ChatOptions { + /** + * @return Returns FunctionCallingOptionsBuilder to create a new instance of + * FunctionCallingOptions. + */ + static FunctionCallingOptionsBuilder builder() { + return new FunctionCallingOptionsBuilder(); + } + /** * Function Callbacks to be registered with the ChatModel. For Prompt Options the * functionCallbacks are automatically enabled for the duration of the prompt @@ -67,16 +76,8 @@ public interface FunctionCallingOptions extends ChatOptions { } } - /** - * @return Returns FunctionCallingOptionsBuilder to create a new instance of - * FunctionCallingOptions. - */ - public static FunctionCallingOptionsBuilder builder() { - return new FunctionCallingOptionsBuilder(); - } - Map getToolContext(); void setToolContext(Map tooContext); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java index ce84c8b04..b5304270e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.function; import java.util.ArrayList; @@ -185,7 +186,7 @@ public class FunctionCallingOptionsBuilder { @Override public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -194,7 +195,7 @@ public class FunctionCallingOptionsBuilder { @Override public Double getFrequencyPenalty() { - return frequencyPenalty; + return this.frequencyPenalty; } public void setFrequencyPenalty(Double frequencyPenalty) { @@ -203,7 +204,7 @@ public class FunctionCallingOptionsBuilder { @Override public Integer getMaxTokens() { - return maxTokens; + return this.maxTokens; } public void setMaxTokens(Integer maxTokens) { @@ -212,7 +213,7 @@ public class FunctionCallingOptionsBuilder { @Override public Double getPresencePenalty() { - return presencePenalty; + return this.presencePenalty; } public void setPresencePenalty(Double presencePenalty) { @@ -221,7 +222,7 @@ public class FunctionCallingOptionsBuilder { @Override public List getStopSequences() { - return stopSequences; + return this.stopSequences; } public void setStopSequences(List stopSequences) { @@ -230,7 +231,7 @@ public class FunctionCallingOptionsBuilder { @Override public Double getTemperature() { - return temperature; + return this.temperature; } public void setTemperature(Double temperature) { @@ -239,7 +240,7 @@ public class FunctionCallingOptionsBuilder { @Override public Integer getTopK() { - return topK; + return this.topK; } public void setTopK(Integer topK) { @@ -248,7 +249,7 @@ public class FunctionCallingOptionsBuilder { @Override public Double getTopP() { - return topP; + return this.topP; } public void setTopP(Double topP) { @@ -257,7 +258,7 @@ public class FunctionCallingOptionsBuilder { @Override public Boolean getProxyToolCalls() { - return proxyToolCalls; + return this.proxyToolCalls; } public void setProxyToolCalls(Boolean proxyToolCalls) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/ToolCallHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/ToolCallHelper.java index f3f23868d..4df569657 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/ToolCallHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/ToolCallHelper.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model.function; import java.util.ArrayList; @@ -7,6 +23,8 @@ import java.util.Optional; import java.util.Set; import java.util.function.Function; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -19,8 +37,6 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.util.CollectionUtils; -import reactor.core.publisher.Flux; - /** * Helper class that reuses the {@link AbstractToolCallSupport} to implement the function * call handling logic on the client side. Used when the withProxyToolCalls(true) option @@ -28,36 +44,6 @@ import reactor.core.publisher.Flux; */ public class ToolCallHelper extends AbstractToolCallSupport { - /** - * Helper used to provide only the function definition, without the actual function - * call implementation. - */ - public static record FunctionDefinition(String name, String description, - String inputTypeSchema) implements FunctionCallback { - - @Override - public String getName() { - return this.name(); - } - - @Override - public String getDescription() { - return this.description(); - } - - @Override - public String getInputTypeSchema() { - return this.inputTypeSchema(); - } - - @Override - public String call(String functionInput) { - throw new UnsupportedOperationException( - "FunctionDefinition provides only metadata. It doesn't implement the call method."); - } - - } - public ToolCallHelper() { this(null, PortableFunctionCallingOptions.builder().build(), List.of()); } @@ -163,4 +149,34 @@ public class ToolCallHelper extends AbstractToolCallSupport { return processCall(chatModel, prompt2, finishReasons, customFunction); } -} \ No newline at end of file + /** + * Helper used to provide only the function definition, without the actual function + * call implementation. + */ + public static record FunctionDefinition(String name, String description, + String inputTypeSchema) implements FunctionCallback { + + @Override + public String getName() { + return this.name(); + } + + @Override + public String getDescription() { + return this.description(); + } + + @Override + public String getInputTypeSchema() { + return this.inputTypeSchema(); + } + + @Override + public String call(String functionInput) { + throw new UnsupportedOperationException( + "FunctionDefinition provides only metadata. It doesn't implement the call method."); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java index ae6176b78..8ff8584c4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.function; import java.lang.reflect.GenericArrayType; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ErrorLoggingObservationHandler.java b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ErrorLoggingObservationHandler.java index 21b92e75b..ff9e0a738 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ErrorLoggingObservationHandler.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ErrorLoggingObservationHandler.java @@ -1,25 +1,24 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model.observation; import java.util.List; import java.util.function.Consumer; -import org.springframework.util.Assert; - import io.micrometer.observation.Observation; import io.micrometer.observation.Observation.Context; import io.micrometer.observation.ObservationHandler; @@ -28,6 +27,8 @@ import io.micrometer.tracing.handler.TracingObservationHandler.TracingContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.util.Assert; + /** * @author Christian Tzolov * @since 1.0.0 diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelObservationContext.java index 0c0ac6767..931fdf976 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelObservationContext.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.observation; import io.micrometer.observation.Observation; + import org.springframework.ai.observation.AiOperationMetadata; import org.springframework.lang.Nullable; import org.springframework.util.Assert; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelUsageMetricsGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelUsageMetricsGenerator.java index dfd6e6c84..4a5eb8eaf 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelUsageMetricsGenerator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelUsageMetricsGenerator.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,21 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.observation; +import java.util.ArrayList; +import java.util.List; + import io.micrometer.common.KeyValue; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.Tag; import io.micrometer.observation.Observation; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.observation.conventions.AiObservationMetricAttributes; import org.springframework.ai.observation.conventions.AiObservationMetricNames; import org.springframework.ai.observation.conventions.AiTokenType; -import java.util.ArrayList; -import java.util.List; - /** * Generate metrics about the model usage in the context of an AI operation. * @@ -38,6 +40,9 @@ public final class ModelUsageMetricsGenerator { private static final String DESCRIPTION = "Measures number of input and output tokens used"; + private ModelUsageMetricsGenerator() { + } + public static void generate(Usage usage, Observation.Context context, MeterRegistry meterRegistry) { if (usage.getPromptTokens() != null) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/package-info.java index 1d5817736..867e8507c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,4 +19,4 @@ package org.springframework.ai.model.observation; import org.springframework.lang.NonNullApi; -import org.springframework.lang.NonNullFields; \ No newline at end of file +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/model/package-info.java index 57c2b34db..207af410e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Provides a set of interfaces and classes for a generic API designed to interact with * various AI models. This package includes interfaces for handling AI model calls, @@ -23,4 +24,5 @@ * ensuring a broad applicability across diverse AI scenarios. * */ -package org.springframework.ai.model; \ No newline at end of file + +package org.springframework.ai.model; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/Categories.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/Categories.java index 3a170028e..1d0be3c39 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/Categories.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/Categories.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.moderation; import java.util.Objects; @@ -10,29 +26,29 @@ import java.util.Objects; * @author Ahmed Yousri * @since 1.0.0 */ -public class Categories { +public final class Categories { - private boolean sexual; + private final boolean sexual; - private boolean hate; + private final boolean hate; - private boolean harassment; + private final boolean harassment; - private boolean selfHarm; + private final boolean selfHarm; - private boolean sexualMinors; + private final boolean sexualMinors; - private boolean hateThreatening; + private final boolean hateThreatening; - private boolean violenceGraphic; + private final boolean violenceGraphic; - private boolean selfHarmIntent; + private final boolean selfHarmIntent; - private boolean selfHarmInstructions; + private final boolean selfHarmInstructions; - private boolean harassmentThreatening; + private final boolean harassmentThreatening; - private boolean violence; + private final boolean violence; private Categories(Builder builder) { this.sexual = builder.sexual; @@ -48,52 +64,84 @@ public class Categories { this.violence = builder.violence; } + public static Builder builder() { + return new Builder(); + } + public boolean isSexual() { - return sexual; + return this.sexual; } public boolean isHate() { - return hate; + return this.hate; } public boolean isHarassment() { - return harassment; + return this.harassment; } public boolean isSelfHarm() { - return selfHarm; + return this.selfHarm; } public boolean isSexualMinors() { - return sexualMinors; + return this.sexualMinors; } public boolean isHateThreatening() { - return hateThreatening; + return this.hateThreatening; } public boolean isViolenceGraphic() { - return violenceGraphic; + return this.violenceGraphic; } public boolean isSelfHarmIntent() { - return selfHarmIntent; + return this.selfHarmIntent; } public boolean isSelfHarmInstructions() { - return selfHarmInstructions; + return this.selfHarmInstructions; } public boolean isHarassmentThreatening() { - return harassmentThreatening; + return this.harassmentThreatening; } public boolean isViolence() { - return violence; + return this.violence; } - public static Builder builder() { - return new Builder(); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Categories)) { + return false; + } + Categories that = (Categories) o; + return this.sexual == that.sexual && this.hate == that.hate && this.harassment == that.harassment + && this.selfHarm == that.selfHarm && this.sexualMinors == that.sexualMinors + && this.hateThreatening == that.hateThreatening && this.violenceGraphic == that.violenceGraphic + && this.selfHarmIntent == that.selfHarmIntent && this.selfHarmInstructions == that.selfHarmInstructions + && this.harassmentThreatening == that.harassmentThreatening && this.violence == that.violence; + } + + @Override + public int hashCode() { + return Objects.hash(this.sexual, this.hate, this.harassment, this.selfHarm, this.sexualMinors, + this.hateThreatening, this.violenceGraphic, this.selfHarmIntent, this.selfHarmInstructions, + this.harassmentThreatening, this.violence); + } + + @Override + public String toString() { + return "Categories{" + "sexual=" + this.sexual + ", hate=" + this.hate + ", harassment=" + this.harassment + + ", selfHarm=" + this.selfHarm + ", sexualMinors=" + this.sexualMinors + ", hateThreatening=" + + this.hateThreatening + ", violenceGraphic=" + this.violenceGraphic + ", selfHarmIntent=" + + this.selfHarmIntent + ", selfHarmInstructions=" + this.selfHarmInstructions + + ", harassmentThreatening=" + this.harassmentThreatening + ", violence=" + this.violence + '}'; } public static class Builder { @@ -181,33 +229,4 @@ public class Categories { } - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (!(o instanceof Categories)) - return false; - Categories that = (Categories) o; - return sexual == that.sexual && hate == that.hate && harassment == that.harassment && selfHarm == that.selfHarm - && sexualMinors == that.sexualMinors && hateThreatening == that.hateThreatening - && violenceGraphic == that.violenceGraphic && selfHarmIntent == that.selfHarmIntent - && selfHarmInstructions == that.selfHarmInstructions - && harassmentThreatening == that.harassmentThreatening && violence == that.violence; - } - - @Override - public int hashCode() { - return Objects.hash(sexual, hate, harassment, selfHarm, sexualMinors, hateThreatening, violenceGraphic, - selfHarmIntent, selfHarmInstructions, harassmentThreatening, violence); - } - - @Override - public String toString() { - return "Categories{" + "sexual=" + sexual + ", hate=" + hate + ", harassment=" + harassment + ", selfHarm=" - + selfHarm + ", sexualMinors=" + sexualMinors + ", hateThreatening=" + hateThreatening - + ", violenceGraphic=" + violenceGraphic + ", selfHarmIntent=" + selfHarmIntent - + ", selfHarmInstructions=" + selfHarmInstructions + ", harassmentThreatening=" + harassmentThreatening - + ", violence=" + violence + '}'; - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/CategoryScores.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/CategoryScores.java index 8429b7834..c96dc4e2b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/CategoryScores.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/CategoryScores.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.moderation; import java.util.Objects; @@ -10,29 +26,29 @@ import java.util.Objects; * @author Ahmed Yousri * @since 1.0.0 */ -public class CategoryScores { +public final class CategoryScores { - private double sexual; + private final double sexual; - private double hate; + private final double hate; - private double harassment; + private final double harassment; - private double selfHarm; + private final double selfHarm; - private double sexualMinors; + private final double sexualMinors; - private double hateThreatening; + private final double hateThreatening; - private double violenceGraphic; + private final double violenceGraphic; - private double selfHarmIntent; + private final double selfHarmIntent; - private double selfHarmInstructions; + private final double selfHarmInstructions; - private double harassmentThreatening; + private final double harassmentThreatening; - private double violence; + private final double violence; private CategoryScores(Builder builder) { this.sexual = builder.sexual; @@ -48,52 +64,89 @@ public class CategoryScores { this.violence = builder.violence; } + public static Builder builder() { + return new Builder(); + } + public double getSexual() { - return sexual; + return this.sexual; } public double getHate() { - return hate; + return this.hate; } public double getHarassment() { - return harassment; + return this.harassment; } public double getSelfHarm() { - return selfHarm; + return this.selfHarm; } public double getSexualMinors() { - return sexualMinors; + return this.sexualMinors; } public double getHateThreatening() { - return hateThreatening; + return this.hateThreatening; } public double getViolenceGraphic() { - return violenceGraphic; + return this.violenceGraphic; } public double getSelfHarmIntent() { - return selfHarmIntent; + return this.selfHarmIntent; } public double getSelfHarmInstructions() { - return selfHarmInstructions; + return this.selfHarmInstructions; } public double getHarassmentThreatening() { - return harassmentThreatening; + return this.harassmentThreatening; } public double getViolence() { - return violence; + return this.violence; } - public static Builder builder() { - return new Builder(); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof CategoryScores)) { + return false; + } + CategoryScores that = (CategoryScores) o; + return Double.compare(that.sexual, this.sexual) == 0 && Double.compare(that.hate, this.hate) == 0 + && Double.compare(that.harassment, this.harassment) == 0 + && Double.compare(that.selfHarm, this.selfHarm) == 0 + && Double.compare(that.sexualMinors, this.sexualMinors) == 0 + && Double.compare(that.hateThreatening, this.hateThreatening) == 0 + && Double.compare(that.violenceGraphic, this.violenceGraphic) == 0 + && Double.compare(that.selfHarmIntent, this.selfHarmIntent) == 0 + && Double.compare(that.selfHarmInstructions, this.selfHarmInstructions) == 0 + && Double.compare(that.harassmentThreatening, this.harassmentThreatening) == 0 + && Double.compare(that.violence, this.violence) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(this.sexual, this.hate, this.harassment, this.selfHarm, this.sexualMinors, + this.hateThreatening, this.violenceGraphic, this.selfHarmIntent, this.selfHarmInstructions, + this.harassmentThreatening, this.violence); + } + + @Override + public String toString() { + return "CategoryScores{" + "sexual=" + this.sexual + ", hate=" + this.hate + ", harassment=" + this.harassment + + ", selfHarm=" + this.selfHarm + ", sexualMinors=" + this.sexualMinors + ", hateThreatening=" + + this.hateThreatening + ", violenceGraphic=" + this.violenceGraphic + ", selfHarmIntent=" + + this.selfHarmIntent + ", selfHarmInstructions=" + this.selfHarmInstructions + + ", harassmentThreatening=" + this.harassmentThreatening + ", violence=" + this.violence + '}'; } public static class Builder { @@ -181,37 +234,4 @@ public class CategoryScores { } - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (!(o instanceof CategoryScores)) - return false; - CategoryScores that = (CategoryScores) o; - return Double.compare(that.sexual, sexual) == 0 && Double.compare(that.hate, hate) == 0 - && Double.compare(that.harassment, harassment) == 0 && Double.compare(that.selfHarm, selfHarm) == 0 - && Double.compare(that.sexualMinors, sexualMinors) == 0 - && Double.compare(that.hateThreatening, hateThreatening) == 0 - && Double.compare(that.violenceGraphic, violenceGraphic) == 0 - && Double.compare(that.selfHarmIntent, selfHarmIntent) == 0 - && Double.compare(that.selfHarmInstructions, selfHarmInstructions) == 0 - && Double.compare(that.harassmentThreatening, harassmentThreatening) == 0 - && Double.compare(that.violence, violence) == 0; - } - - @Override - public int hashCode() { - return Objects.hash(sexual, hate, harassment, selfHarm, sexualMinors, hateThreatening, violenceGraphic, - selfHarmIntent, selfHarmInstructions, harassmentThreatening, violence); - } - - @Override - public String toString() { - return "CategoryScores{" + "sexual=" + sexual + ", hate=" + hate + ", harassment=" + harassment + ", selfHarm=" - + selfHarm + ", sexualMinors=" + sexualMinors + ", hateThreatening=" + hateThreatening - + ", violenceGraphic=" + violenceGraphic + ", selfHarmIntent=" + selfHarmIntent - + ", selfHarmInstructions=" + selfHarmInstructions + ", harassmentThreatening=" + harassmentThreatening - + ", violence=" + violence + '}'; - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/Generation.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/Generation.java index 98a4cf5fd..e73ebb232 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/Generation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/Generation.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,18 +52,18 @@ public class Generation implements ModelResult { @Override public Moderation getOutput() { - return moderation; + return this.moderation; } @Override public ModerationGenerationMetadata getMetadata() { - return moderationGenerationMetadata; + return this.moderationGenerationMetadata; } @Override public String toString() { - return "Generation{" + "moderationGenerationMetadata=" + moderationGenerationMetadata + ", moderation=" - + moderation + '}'; + return "Generation{" + "moderationGenerationMetadata=" + this.moderationGenerationMetadata + ", moderation=" + + this.moderation + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/Moderation.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/Moderation.java index a98b94c72..7fe43a948 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/Moderation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/Moderation.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.moderation; import java.util.Arrays; @@ -12,7 +28,7 @@ import java.util.Objects; * @author Ahmed Yousri * @since 1.0.0 */ -public class Moderation { +public final class Moderation { private final String id; @@ -26,42 +42,44 @@ public class Moderation { this.results = builder.moderationResultList; } + public static Builder builder() { + return new Builder(); + } + public String getId() { - return id; + return this.id; } public String getModel() { - return model; + return this.model; } public List getResults() { - return results; + return this.results; } @Override public String toString() { - return "Moderation{" + "id='" + id + '\'' + ", model='" + model + '\'' + ", results=" - + Arrays.toString(results.toArray()) + '}'; + return "Moderation{" + "id='" + this.id + '\'' + ", model='" + this.model + '\'' + ", results=" + + Arrays.toString(this.results.toArray()) + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof Moderation)) + } + if (!(o instanceof Moderation)) { return false; + } Moderation that = (Moderation) o; - return Objects.equals(id, that.id) && Objects.equals(model, that.model) - && Objects.equals(results, that.results); + return Objects.equals(this.id, that.id) && Objects.equals(this.model, that.model) + && Objects.equals(this.results, that.results); } @Override public int hashCode() { - return Objects.hash(id, model, results); - } - - public static Builder builder() { - return new Builder(); + return Objects.hash(this.id, this.model, this.results); } public static class Builder { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationGenerationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationGenerationMetadata.java index f186ec54d..5cb66e4ad 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationGenerationMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationGenerationMetadata.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationMessage.java index 455dd695c..335f9f5a1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationMessage.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ public class ModerationMessage { } public String getText() { - return text; + return this.text; } public void setText(String text) { @@ -44,22 +44,24 @@ public class ModerationMessage { @Override public String toString() { - return "ModerationMessage{" + "text='" + text + '\'' + '}'; + return "ModerationMessage{" + "text='" + this.text + '\'' + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ModerationMessage)) + } + if (!(o instanceof ModerationMessage)) { return false; + } ModerationMessage that = (ModerationMessage) o; - return Objects.equals(text, that.text); + return Objects.equals(this.text, that.text); } @Override public int hashCode() { - return Objects.hash(text); + return Objects.hash(this.text); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationModel.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationModel.java index d7d47bf67..188fce42a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptions.java index 57ac68f43..238989f07 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptionsBuilder.java index edacf2cba..b476d33c2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationOptionsBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,22 +25,7 @@ package org.springframework.ai.moderation; * @author Ahmed Yousri * @since 1.0.0 */ -public class ModerationOptionsBuilder { - - private class ModerationModelOptionsImpl implements ModerationOptions { - - private String model; - - public void setModel(String model) { - this.model = model; - } - - @Override - public String getModel() { - return model; - } - - } +public final class ModerationOptionsBuilder { private final ModerationModelOptionsImpl options = new ModerationModelOptionsImpl(); @@ -53,12 +38,27 @@ public class ModerationOptionsBuilder { } public ModerationOptionsBuilder withModel(String model) { - options.setModel(model); + this.options.setModel(model); return this; } public ModerationOptions build() { - return options; + return this.options; + } + + private class ModerationModelOptionsImpl implements ModerationOptions { + + private String model; + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationPrompt.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationPrompt.java index e783cb84f..02514d4b9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationPrompt.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationPrompt.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,10 @@ package org.springframework.ai.moderation; -import org.springframework.ai.model.ModelRequest; import java.util.Objects; +import org.springframework.ai.model.ModelRequest; + /** * Represents a prompt for moderation containing a single message and the options for the * moderation model. This class offers constructors to create a prompt from a single @@ -50,11 +51,11 @@ public class ModerationPrompt implements ModelRequest { @Override public ModerationMessage getInstructions() { - return message; + return this.message; } public ModerationOptions getOptions() { - return moderationModelOptions; + return this.moderationModelOptions; } public void setOptions(ModerationOptions moderationModelOptions) { @@ -63,23 +64,26 @@ public class ModerationPrompt implements ModelRequest { @Override public String toString() { - return "ModerationPrompt{" + "message=" + message + ", moderationModelOptions=" + moderationModelOptions + '}'; + return "ModerationPrompt{" + "message=" + this.message + ", moderationModelOptions=" + + this.moderationModelOptions + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ModerationPrompt)) + } + if (!(o instanceof ModerationPrompt)) { return false; + } ModerationPrompt that = (ModerationPrompt) o; - return Objects.equals(message, that.message) - && Objects.equals(moderationModelOptions, that.moderationModelOptions); + return Objects.equals(this.message, that.message) + && Objects.equals(this.moderationModelOptions, that.moderationModelOptions); } @Override public int hashCode() { - return Objects.hash(message, moderationModelOptions); + return Objects.hash(this.message, this.moderationModelOptions); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponse.java index 5da1469f2..043104436 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponse.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,11 @@ package org.springframework.ai.moderation; -import org.springframework.ai.model.ModelResponse; - import java.util.List; import java.util.Objects; +import org.springframework.ai.model.ModelResponse; + /** * Represents a response from a moderation process, encapsulating the moderation metadata * and the generated content. This class provides access to both the single generation @@ -48,38 +48,40 @@ public class ModerationResponse implements ModelResponse { @Override public Generation getResult() { - return generations; + return this.generations; } @Override public List getResults() { - return List.of(generations); + return List.of(this.generations); } @Override public ModerationResponseMetadata getMetadata() { - return moderationResponseMetadata; + return this.moderationResponseMetadata; } @Override public String toString() { - return "ModerationResponse{" + "moderationResponseMetadata=" + moderationResponseMetadata + ", generations=" - + generations + '}'; + return "ModerationResponse{" + "moderationResponseMetadata=" + this.moderationResponseMetadata + + ", generations=" + this.generations + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof ModerationResponse that)) + } + if (!(o instanceof ModerationResponse that)) { return false; - return Objects.equals(moderationResponseMetadata, that.moderationResponseMetadata) - && Objects.equals(generations, that.generations); + } + return Objects.equals(this.moderationResponseMetadata, that.moderationResponseMetadata) + && Objects.equals(this.generations, that.generations); } @Override public int hashCode() { - return Objects.hash(moderationResponseMetadata, generations); + return Objects.hash(this.moderationResponseMetadata, this.generations); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponseMetadata.java index 785d598c7..c32804dea 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResponseMetadata.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResult.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResult.java index d7ec33e5d..29ea10830 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResult.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/ModerationResult.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.moderation; import java.util.Objects; @@ -10,7 +26,7 @@ import java.util.Objects; * @author Ahmed Yousri * @since 1.0.0 */ -public class ModerationResult { +public final class ModerationResult { private boolean flagged; @@ -24,8 +40,12 @@ public class ModerationResult { this.categoryScores = builder.categoryScores; } + public static Builder builder() { + return new Builder(); + } + public boolean isFlagged() { - return flagged; + return this.flagged; } public void setFlagged(boolean flagged) { @@ -33,7 +53,7 @@ public class ModerationResult { } public Categories getCategories() { - return categories; + return this.categories; } public void setCategories(Categories categories) { @@ -41,15 +61,35 @@ public class ModerationResult { } public CategoryScores getCategoryScores() { - return categoryScores; + return this.categoryScores; } public void setCategoryScores(CategoryScores categoryScores) { this.categoryScores = categoryScores; } - public static Builder builder() { - return new Builder(); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ModerationResult)) { + return false; + } + ModerationResult that = (ModerationResult) o; + return this.flagged == that.flagged && Objects.equals(this.categories, that.categories) + && Objects.equals(this.categoryScores, that.categoryScores); + } + + @Override + public int hashCode() { + return Objects.hash(this.flagged, this.categories, this.categoryScores); + } + + @Override + public String toString() { + return "ModerationResult{" + "flagged=" + this.flagged + ", categories=" + this.categories + ", categoryScores=" + + this.categoryScores + '}'; } public static class Builder { @@ -81,26 +121,4 @@ public class ModerationResult { } - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (!(o instanceof ModerationResult)) - return false; - ModerationResult that = (ModerationResult) o; - return flagged == that.flagged && Objects.equals(categories, that.categories) - && Objects.equals(categoryScores, that.categoryScores); - } - - @Override - public int hashCode() { - return Objects.hash(flagged, categories, categoryScores); - } - - @Override - public String toString() { - return "ModerationResult{" + "flagged=" + flagged + ", categories=" + categories + ", categoryScores=" - + categoryScores + '}'; - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/AiOperationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/AiOperationMetadata.java index 68b1c3dff..a8707b1e6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/AiOperationMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/AiOperationMetadata.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation; import org.springframework.ai.observation.conventions.AiOperationType; @@ -41,7 +42,7 @@ public record AiOperationMetadata(String operationType, String provider) { return new Builder(); } - public static class Builder { + public static final class Builder { private String operationType; @@ -61,7 +62,7 @@ public record AiOperationMetadata(String operationType, String provider) { } public AiOperationMetadata build() { - return new AiOperationMetadata(operationType, provider); + return new AiOperationMetadata(this.operationType, this.provider); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationAttributes.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationAttributes.java index 295664493..eea71318b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationAttributes.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationAttributes.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** @@ -141,7 +142,7 @@ public enum AiObservationAttributes { } public String value() { - return value; + return this.value; } // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationEventNames.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationEventNames.java index c3f86f353..a44ce86ff 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationEventNames.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationEventNames.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** @@ -39,7 +40,7 @@ public enum AiObservationEventNames { } public String value() { - return value; + return this.value; } // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricAttributes.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricAttributes.java index e8d828e7d..b729a8e5b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricAttributes.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricAttributes.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** @@ -41,7 +42,7 @@ public enum AiObservationMetricAttributes { } public String value() { - return value; + return this.value; } // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricNames.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricNames.java index 358755319..fb8ca023a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricNames.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiObservationMetricNames.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** @@ -39,7 +40,7 @@ public enum AiObservationMetricNames { } public String value() { - return value; + return this.value; } // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiOperationType.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiOperationType.java index 85fa4f2a2..3defa442f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiOperationType.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiOperationType.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java index 88d6a5aaf..63e2403c0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiTokenType.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiTokenType.java index a8c2fec38..013731f94 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiTokenType.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiTokenType.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** @@ -40,7 +41,7 @@ public enum AiTokenType { } public String value() { - return value; + return this.value; } // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/SpringAiKind.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/SpringAiKind.java index 11c70f459..d23861d51 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/SpringAiKind.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/SpringAiKind.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationAttributes.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationAttributes.java index bd869f0cb..e6f02cdb9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationAttributes.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationAttributes.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** @@ -109,7 +110,7 @@ public enum VectorStoreObservationAttributes { } public String value() { - return value; + return this.value; } // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationEventNames.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationEventNames.java index 9dc843e13..589b7ffcd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationEventNames.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreObservationEventNames.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation.conventions; /** @@ -34,7 +35,7 @@ public enum VectorStoreObservationEventNames { } public String value() { - return value; + return this.value; } // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreProvider.java index 2ceaf2f54..bd518c982 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreProvider.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.observation.conventions; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreSimilarityMetric.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreSimilarityMetric.java index ec60701a4..9a1ac00ca 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreSimilarityMetric.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreSimilarityMetric.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.observation.conventions; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/package-info.java index 53f533019..34a4401d0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,4 +19,4 @@ package org.springframework.ai.observation.conventions; import org.springframework.lang.NonNullApi; -import org.springframework.lang.NonNullFields; \ No newline at end of file +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/package-info.java index 1ef4dfd32..023d5dc69 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,4 +19,4 @@ package org.springframework.ai.observation; import org.springframework.lang.NonNullApi; -import org.springframework.lang.NonNullFields; \ No newline at end of file +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/tracing/TracingHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/tracing/TracingHelper.java index fce9a93e8..a675fd98e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/tracing/TracingHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/tracing/TracingHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.observation.tracing; -import io.micrometer.tracing.handler.TracingObservationHandler; -import io.opentelemetry.api.trace.Span; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.observation.ChatModelObservationContext; -import org.springframework.ai.model.Content; -import org.springframework.lang.Nullable; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; +package org.springframework.ai.observation.tracing; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; @@ -31,6 +22,13 @@ import java.util.List; import java.util.Map; import java.util.StringJoiner; +import io.micrometer.tracing.handler.TracingObservationHandler; +import io.opentelemetry.api.trace.Span; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.lang.Nullable; + /** * Utilities to prepare and process traces for observability. * @@ -40,6 +38,9 @@ public final class TracingHelper { private static final Logger logger = LoggerFactory.getLogger(TracingHelper.class); + private TracingHelper() { + } + @Nullable public static Span extractOtelSpan(@Nullable TracingObservationHandler.TracingContext tracingContext) { if (tracingContext == null) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/reader/EmptyJsonMetadataGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/reader/EmptyJsonMetadataGenerator.java index a56714aef..9ba62979f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/reader/EmptyJsonMetadataGenerator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/reader/EmptyJsonMetadataGenerator.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader; import java.util.Collections; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/reader/ExtractedTextFormatter.java b/spring-ai-core/src/main/java/org/springframework/ai/reader/ExtractedTextFormatter.java index 03669f1b7..31112672c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/reader/ExtractedTextFormatter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/reader/ExtractedTextFormatter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader; import org.springframework.util.StringUtils; @@ -32,19 +33,19 @@ import org.springframework.util.StringUtils; * * @author Christian Tzolov */ -public class ExtractedTextFormatter { +public final class ExtractedTextFormatter { /** Flag indicating if the text should be left-aligned */ - private boolean leftAlignment; + private final boolean leftAlignment; /** Number of top pages to skip before performing delete operations */ - private int numberOfTopPagesToSkipBeforeDelete; + private final int numberOfTopPagesToSkipBeforeDelete; /** Number of top text lines to delete from a page */ - private int numberOfTopTextLinesToDelete; + private final int numberOfTopTextLinesToDelete; /** Number of bottom text lines to delete from a page */ - private int numberOfBottomTextLinesToDelete; + private final int numberOfBottomTextLinesToDelete; /** * Private constructor to initialize the formatter from the builder. @@ -73,6 +74,82 @@ public class ExtractedTextFormatter { return new Builder().build(); } + /** + * Replaces multiple, adjacent blank lines into a single blank line. + * @param pageText text to adjust the blank lines for. + * @return Returns the same text but with blank lines trimmed. + */ + public static String trimAdjacentBlankLines(String pageText) { + return pageText.replaceAll("(?m)(^ *\n)", "\n").replaceAll("(?m)^$([\r\n]+?)(^$[\r\n]+?^)+", "$1"); + } + + /** + * @param pageText text to align. + * @return Returns the same text but aligned to the left side. + */ + public static String alignToLeft(String pageText) { + return pageText.replaceAll("(?m)(^ *| +(?= |$))", "").replaceAll("(?m)^$( ?)(^$[\r\n]+?^)+", "$1"); + } + + /** + * Removes the specified number of lines from the bottom part of the text. + * @param pageText Text to remove lines from. + * @param numberOfLines Number of lines to remove. + * @return Returns the text striped from last lines. + */ + public static String deleteBottomTextLines(String pageText, int numberOfLines) { + if (!StringUtils.hasText(pageText)) { + return pageText; + } + + int lineCount = 0; + int truncateIndex = pageText.length(); + int nextTruncateIndex = truncateIndex; + while (lineCount < numberOfLines && nextTruncateIndex >= 0) { + nextTruncateIndex = pageText.lastIndexOf(System.lineSeparator(), truncateIndex - 1); + truncateIndex = nextTruncateIndex < 0 ? truncateIndex : nextTruncateIndex; + lineCount++; + } + return pageText.substring(0, truncateIndex); + } + + /** + * Removes a specified number of lines from the top part of the given text. + * + *

+ * This method takes a text and trims it by removing a certain number of lines from + * the top. If the provided text is null or contains only whitespace, it will be + * returned as is. If the number of lines to remove exceeds the actual number of lines + * in the text, the result will be an empty string. + *

+ * + *

+ * The method identifies lines based on the system's line separator, making it + * compatible with different platforms. + *

+ * @param pageText The text from which the top lines need to be removed. If this is + * null, empty, or consists only of whitespace, it will be returned unchanged. + * @param numberOfLines The number of lines to remove from the top of the text. If + * this exceeds the actual number of lines in the text, an empty string will be + * returned. + * @return The text with the specified number of lines removed from the top. + */ + public static String deleteTopTextLines(String pageText, int numberOfLines) { + if (!StringUtils.hasText(pageText)) { + return pageText; + } + int lineCount = 0; + + int truncateIndex = 0; + int nextTruncateIndex = truncateIndex; + while (lineCount < numberOfLines && nextTruncateIndex >= 0) { + nextTruncateIndex = pageText.indexOf(System.lineSeparator(), truncateIndex + 1); + truncateIndex = nextTruncateIndex < 0 ? truncateIndex : nextTruncateIndex; + lineCount++; + } + return pageText.substring(truncateIndex); + } + /** * Formats the provided text according to the formatter's configuration. * @param pageText Text to be formatted. @@ -126,7 +203,7 @@ public class ExtractedTextFormatter { *
  • Number of top text lines to delete to 0
  • *
  • Number of bottom text lines to delete to 0
  • * - * + * * *

    * After configuring the builder, calling the {@link #build()} method will return a @@ -209,80 +286,4 @@ public class ExtractedTextFormatter { } - /** - * Replaces multiple, adjacent blank lines into a single blank line. - * @param pageText text to adjust the blank lines for. - * @return Returns the same text but with blank lines trimmed. - */ - public static String trimAdjacentBlankLines(String pageText) { - return pageText.replaceAll("(?m)(^ *\n)", "\n").replaceAll("(?m)^$([\r\n]+?)(^$[\r\n]+?^)+", "$1"); - } - - /** - * @param pageText text to align. - * @return Returns the same text but aligned to the left side. - */ - public static String alignToLeft(String pageText) { - return pageText.replaceAll("(?m)(^ *| +(?= |$))", "").replaceAll("(?m)^$( ?)(^$[\r\n]+?^)+", "$1"); - } - - /** - * Removes the specified number of lines from the bottom part of the text. - * @param pageText Text to remove lines from. - * @param numberOfLines Number of lines to remove. - * @return Returns the text striped from last lines. - */ - public static String deleteBottomTextLines(String pageText, int numberOfLines) { - if (!StringUtils.hasText(pageText)) { - return pageText; - } - - int lineCount = 0; - int truncateIndex = pageText.length(); - int nextTruncateIndex = truncateIndex; - while (lineCount < numberOfLines && nextTruncateIndex >= 0) { - nextTruncateIndex = pageText.lastIndexOf(System.lineSeparator(), truncateIndex - 1); - truncateIndex = nextTruncateIndex < 0 ? truncateIndex : nextTruncateIndex; - lineCount++; - } - return pageText.substring(0, truncateIndex); - } - - /** - * Removes a specified number of lines from the top part of the given text. - * - *

    - * This method takes a text and trims it by removing a certain number of lines from - * the top. If the provided text is null or contains only whitespace, it will be - * returned as is. If the number of lines to remove exceeds the actual number of lines - * in the text, the result will be an empty string. - *

    - * - *

    - * The method identifies lines based on the system's line separator, making it - * compatible with different platforms. - *

    - * @param pageText The text from which the top lines need to be removed. If this is - * null, empty, or consists only of whitespace, it will be returned unchanged. - * @param numberOfLines The number of lines to remove from the top of the text. If - * this exceeds the actual number of lines in the text, an empty string will be - * returned. - * @return The text with the specified number of lines removed from the top. - */ - public static String deleteTopTextLines(String pageText, int numberOfLines) { - if (!StringUtils.hasText(pageText)) { - return pageText; - } - int lineCount = 0; - - int truncateIndex = 0; - int nextTruncateIndex = truncateIndex; - while (lineCount < numberOfLines && nextTruncateIndex >= 0) { - nextTruncateIndex = pageText.indexOf(System.lineSeparator(), truncateIndex + 1); - truncateIndex = nextTruncateIndex < 0 ? truncateIndex : nextTruncateIndex; - lineCount++; - } - return pageText.substring(truncateIndex, pageText.length()); - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonMetadataGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonMetadataGenerator.java index 4a4ffb1e9..a556e8b65 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonMetadataGenerator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonMetadataGenerator.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader; import java.util.Map; @@ -23,6 +24,8 @@ public interface JsonMetadataGenerator { /** * The input is the JSON document represented as a map, the output are the fields * extracted from the input map that will be used as metadata. + * @param jsonMap json document map + * @return json metadata map */ Map generate(Map jsonMap); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java b/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java index 7b4a2e8cc..2ea446400 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader; import java.io.IOException; -import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Collections; import java.util.stream.StreamSupport; import com.fasterxml.jackson.core.type.TypeReference; @@ -37,7 +37,7 @@ import org.springframework.core.io.Resource; * * @author Mark Pollack * @author Christian Tzolov - * @author rivkode + * @author rivkode rivkode * @since 1.0.0 */ public class JsonReader implements DocumentReader { @@ -73,15 +73,15 @@ public class JsonReader implements DocumentReader { @Override public List get() { try { - JsonNode rootNode = objectMapper.readTree(this.resource.getInputStream()); + JsonNode rootNode = this.objectMapper.readTree(this.resource.getInputStream()); if (rootNode.isArray()) { return StreamSupport.stream(rootNode.spliterator(), true) - .map(jsonNode -> parseJsonNode(jsonNode, objectMapper)) + .map(jsonNode -> parseJsonNode(jsonNode, this.objectMapper)) .toList(); } else { - return Collections.singletonList(parseJsonNode(rootNode, objectMapper)); + return Collections.singletonList(parseJsonNode(rootNode, this.objectMapper)); } } catch (IOException e) { @@ -91,12 +91,13 @@ public class JsonReader implements DocumentReader { private Document parseJsonNode(JsonNode jsonNode, ObjectMapper objectMapper) { Map item = objectMapper.convertValue(jsonNode, new TypeReference>() { + }); var sb = new StringBuilder(); - jsonKeysToUse.stream().filter(item::containsKey).forEach(key -> { - sb.append(key).append(": ").append(item.get(key)).append(System.lineSeparator()); - }); + this.jsonKeysToUse.stream() + .filter(item::containsKey) + .forEach(key -> sb.append(key).append(": ").append(item.get(key)).append(System.lineSeparator())); Map metadata = this.jsonMetadataGenerator.generate(item); String content = sb.isEmpty() ? item.toString() : sb.toString(); @@ -106,11 +107,11 @@ public class JsonReader implements DocumentReader { protected List get(JsonNode rootNode) { if (rootNode.isArray()) { return StreamSupport.stream(rootNode.spliterator(), true) - .map(jsonNode -> parseJsonNode(jsonNode, objectMapper)) + .map(jsonNode -> parseJsonNode(jsonNode, this.objectMapper)) .toList(); } else { - return Collections.singletonList(parseJsonNode(rootNode, objectMapper)); + return Collections.singletonList(parseJsonNode(rootNode, this.objectMapper)); } } @@ -122,7 +123,7 @@ public class JsonReader implements DocumentReader { */ public List get(String pointer) { try { - JsonNode rootNode = objectMapper.readTree(this.resource.getInputStream()); + JsonNode rootNode = this.objectMapper.readTree(this.resource.getInputStream()); JsonNode targetNode = rootNode.at(pointer); if (targetNode.isMissingNode()) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/reader/TextReader.java b/spring-ai-core/src/main/java/org/springframework/ai/reader/TextReader.java index 55bb2e2bf..db9e284d0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/reader/TextReader.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/reader/TextReader.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader; import java.io.IOException; @@ -46,13 +47,13 @@ public class TextReader implements DocumentReader { */ private final Resource resource; + private final Map customMetadata = new HashMap<>(); + /** * Character set to be used when loading data from the */ private Charset charset = StandardCharsets.UTF_8; - private final Map customMetadata = new HashMap<>(); - public TextReader(String resourceUrl) { this(new DefaultResourceLoader().getResource(resourceUrl)); } @@ -62,15 +63,15 @@ public class TextReader implements DocumentReader { this.resource = resource; } + public Charset getCharset() { + return this.charset; + } + public void setCharset(Charset charset) { Objects.requireNonNull(charset, "The charset must not be null"); this.charset = charset; } - public Charset getCharset() { - return this.charset; - } - /** * Metadata associated with all documents created by the loader. * @return Metadata to be assigned to the output Documents. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/JTokkitTokenCountEstimator.java b/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/JTokkitTokenCountEstimator.java index 8a1dc60aa..760a9a4b3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/JTokkitTokenCountEstimator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/JTokkitTokenCountEstimator.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -87,4 +87,4 @@ public class JTokkitTokenCountEstimator implements TokenCountEstimator { return totalSize; } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/TokenCountEstimator.java b/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/TokenCountEstimator.java index 03a9eff5a..e33c464e9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/TokenCountEstimator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tokenizer/TokenCountEstimator.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,4 +48,4 @@ public interface TokenCountEstimator { */ int estimate(Iterable messages); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/ContentFormatTransformer.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/ContentFormatTransformer.java index 880abc735..32e201402 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/ContentFormatTransformer.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/ContentFormatTransformer.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformer; import java.util.ArrayList; @@ -67,7 +68,7 @@ public class ContentFormatTransformer implements DocumentTransformer { * @return processed documents */ public List apply(List documents) { - if (contentFormatter != null) { + if (this.contentFormatter != null) { documents.forEach(this::processDocument); } @@ -76,7 +77,7 @@ public class ContentFormatTransformer implements DocumentTransformer { private void processDocument(Document document) { if (document.getContentFormatter() instanceof DefaultContentFormatter docFormatter - && contentFormatter instanceof DefaultContentFormatter toUpdateFormatter) { + && this.contentFormatter instanceof DefaultContentFormatter toUpdateFormatter) { updateFormatter(document, docFormatter, toUpdateFormatter); } @@ -99,7 +100,7 @@ public class ContentFormatTransformer implements DocumentTransformer { .withMetadataTemplate(docFormatter.getMetadataTemplate()) .withMetadataSeparator(docFormatter.getMetadataSeparator()); - if (!disableTemplateRewrite) { + if (!this.disableTemplateRewrite) { builder.withTextTemplate(docFormatter.getTextTemplate()); } @@ -107,7 +108,7 @@ public class ContentFormatTransformer implements DocumentTransformer { } private void overrideFormatter(Document document) { - document.setContentFormatter(contentFormatter); + document.setContentFormatter(this.contentFormatter); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java index 10db9364d..dd02b336b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformer; import java.util.List; import java.util.Map; import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.document.Document; -import org.springframework.ai.document.DocumentTransformer; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentTransformer; import org.springframework.util.Assert; /** @@ -32,14 +33,14 @@ import org.springframework.util.Assert; */ public class KeywordMetadataEnricher implements DocumentTransformer { - private static final String EXCERPT_KEYWORDS_METADATA_KEY = "excerpt_keywords"; - public static final String CONTEXT_STR_PLACEHOLDER = "context_str"; public static final String KEYWORDS_TEMPLATE = """ {context_str}. Give %s unique keywords for this document. Format as comma separated. Keywords: """; + private static final String EXCERPT_KEYWORDS_METADATA_KEY = "excerpt_keywords"; + /** * Model predictor */ @@ -62,7 +63,7 @@ public class KeywordMetadataEnricher implements DocumentTransformer { public List apply(List documents) { for (Document document : documents) { - var template = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, keywordCount)); + var template = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, this.keywordCount)); Prompt prompt = template.create(Map.of(CONTEXT_STR_PLACEHOLDER, document.getContent())); String keywords = this.chatModel.call(prompt).getResult().getOutput().getContent(); document.getMetadata().putAll(Map.of(EXCERPT_KEYWORDS_METADATA_KEY, keywords)); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java index 60ecf450b..a15378282 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformer; import java.util.ArrayList; @@ -21,11 +22,11 @@ import java.util.List; import java.util.Map; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentTransformer; import org.springframework.ai.document.MetadataMode; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -38,14 +39,6 @@ import org.springframework.util.CollectionUtils; */ public class SummaryMetadataEnricher implements DocumentTransformer { - private static final String SECTION_SUMMARY_METADATA_KEY = "section_summary"; - - private static final String NEXT_SECTION_SUMMARY_METADATA_KEY = "next_section_summary"; - - private static final String PREV_SECTION_SUMMARY_METADATA_KEY = "prev_section_summary"; - - private static final String CONTEXT_STR_PLACEHOLDER = "context_str"; - public static final String DEFAULT_SUMMARY_EXTRACT_TEMPLATE = """ Here is the content of the section: {context_str} @@ -54,11 +47,13 @@ public class SummaryMetadataEnricher implements DocumentTransformer { Summary:"""; - public enum SummaryType { + private static final String SECTION_SUMMARY_METADATA_KEY = "section_summary"; - PREVIOUS, CURRENT, NEXT + private static final String NEXT_SECTION_SUMMARY_METADATA_KEY = "next_section_summary"; - } + private static final String PREV_SECTION_SUMMARY_METADATA_KEY = "prev_section_summary"; + + private static final String CONTEXT_STR_PLACEHOLDER = "context_str"; /** * AI client. @@ -127,4 +122,10 @@ public class SummaryMetadataEnricher implements DocumentTransformer { return summaryMetadata; } + public enum SummaryType { + + PREVIOUS, CURRENT, NEXT + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TextSplitter.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TextSplitter.java index 809fc556b..7d5439c86 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TextSplitter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TextSplitter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformer.splitter; import java.util.ArrayList; @@ -22,6 +23,7 @@ import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.ContentFormatter; import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentTransformer; @@ -49,14 +51,14 @@ public abstract class TextSplitter implements DocumentTransformer { return this.apply(List.of(document)); } - public void setCopyContentFormatter(boolean copyContentFormatter) { - this.copyContentFormatter = copyContentFormatter; - } - public boolean isCopyContentFormatter() { return this.copyContentFormatter; } + public void setCopyContentFormatter(boolean copyContentFormatter) { + this.copyContentFormatter = copyContentFormatter; + } + private List doSplitDocuments(List documents) { List texts = new ArrayList<>(); List> metadataList = new ArrayList<>(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java index 420e7a287..4c1295436 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformer.splitter; import java.util.ArrayList; @@ -33,10 +34,6 @@ import org.springframework.util.Assert; */ public class TokenTextSplitter extends TextSplitter { - private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry(); - - private final Encoding encoding = registry.getEncoding(EncodingType.CL100K_BASE); - private final static int DEFAULT_CHUNK_SIZE = 800; private final static int MIN_CHUNK_SIZE_CHARS = 350; @@ -47,6 +44,10 @@ public class TokenTextSplitter extends TextSplitter { private final static boolean KEEP_SEPARATOR = true; + private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry(); + + private final Encoding encoding = this.registry.getEncoding(EncodingType.CL100K_BASE); + // The target size of each text chunk in tokens private final int chunkSize; @@ -78,6 +79,10 @@ public class TokenTextSplitter extends TextSplitter { this.keepSeparator = keepSeparator; } + public static Builder builder() { + return new Builder(); + } + @Override protected List splitText(String text) { return doSplit(text, this.chunkSize); @@ -145,11 +150,7 @@ public class TokenTextSplitter extends TextSplitter { return this.encoding.decode(tokensIntArray); } - public static Builder builder() { - return new Builder(); - } - - public static class Builder { + public static final class Builder { private int chunkSize; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/util/JacksonUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/util/JacksonUtils.java index 80176631b..3686dd417 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/util/JacksonUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/util/JacksonUtils.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.util; import java.util.ArrayList; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/util/ParsingUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/util/ParsingUtils.java index 528016256..591aa0801 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/util/ParsingUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/util/ParsingUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2014-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.util; import java.util.ArrayList; @@ -35,8 +36,8 @@ public abstract class ParsingUtils { private static final String LOWER = "\\p{Ll}"; - private static final String CAMEL_CASE_REGEX = "(?> typeRef = new TypeReference<>() { + }; try { Map deserializedMap = this.objectMapper.readValue(file, typeRef); @@ -193,6 +194,7 @@ public class SimpleVectorStore extends AbstractObservationVectorStore { */ public void load(Resource resource) { TypeReference> typeRef = new TypeReference<>() { + }; try { Map deserializedMap = this.objectMapper.readValue(resource.getInputStream(), typeRef); @@ -219,6 +221,15 @@ public class SimpleVectorStore extends AbstractObservationVectorStore { return this.embeddingModel.embed(query); } + @Override + public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { + + return VectorStoreObservationContext.builder(VectorStoreProvider.SIMPLE.value(), operationName) + .withDimensions(this.embeddingModel.dimensions()) + .withCollectionName("in-memory-map") + .withSimilarityMetric(VectorStoreSimilarityMetric.COSINE.value()); + } + public static class Similarity { private String key; @@ -232,7 +243,7 @@ public class SimpleVectorStore extends AbstractObservationVectorStore { } - public class EmbeddingMath { + public final class EmbeddingMath { private EmbeddingMath() { throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); @@ -276,13 +287,4 @@ public class SimpleVectorStore extends AbstractObservationVectorStore { } - @Override - public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { - - return VectorStoreObservationContext.builder(VectorStoreProvider.SIMPLE.value(), operationName) - .withDimensions(this.embeddingModel.dimensions()) - .withCollectionName("in-memory-map") - .withSimilarityMetric(VectorStoreSimilarityMetric.COSINE.value()); - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java index dadcdda3b..1e8cf62dc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java index e3036c075..53e16c691 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; /** @@ -64,32 +65,6 @@ package org.springframework.ai.vectorstore.filter; */ public class Filter { - /** - * Mark interface representing the supported expression types: {@link Key}, - * {@link Value}, {@link Expression} and {@link Group}. - */ - public interface Operand { - - } - - /** - * String identifier representing an expression key. (e.g. the country in the country - * == "NL" expression). - * - * @param key expression key - */ - public record Key(String key) implements Operand { - } - - /** - * Represents expression value constant or constant array. Support Numeric, Boolean - * and String data types. - * - * @param value value constant or constant array - */ - public record Value(Object value) implements Operand { - } - /** * Filter expression operations.
    * @@ -107,6 +82,34 @@ public class Filter { } + /** + * Mark interface representing the supported expression types: {@link Key}, + * {@link Value}, {@link Expression} and {@link Group}. + */ + public interface Operand { + + } + + /** + * String identifier representing an expression key. (e.g. the country in the country + * == "NL" expression). + * + * @param key expression key + */ + public record Key(String key) implements Operand { + + } + + /** + * Represents expression value constant or constant array. Support Numeric, Boolean + * and String data types. + * + * @param value value constant or constant array + */ + public record Value(Object value) implements Operand { + + } + /** * Triple that represents and filter boolean expression as * left type right. @@ -120,9 +123,11 @@ public class Filter { * be another {@link Expression}. */ public record Expression(ExpressionType type, Operand left, Operand right) implements Operand { + public Expression(ExpressionType type, Operand operand) { this(type, operand, null); } + } /** @@ -132,6 +137,7 @@ public class Filter { * @param content Inner expression to be evaluated as a part of the group. */ public record Group(Expression content) implements Operand { + } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java index d3913c1d4..f7410c898 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import java.util.List; @@ -55,20 +56,6 @@ import org.springframework.ai.vectorstore.filter.Filter.Value; */ public class FilterExpressionBuilder { - public record Op(Filter.Operand expression) { - - public Filter.Expression build() { - if (expression instanceof Filter.Group group) { - // Remove the top-level grouping. - return group.content(); - } - else if (expression instanceof Filter.Expression exp) { - return exp; - } - throw new RuntimeException("Invalid expression: " + expression); - } - } - public Op eq(String key, Object value) { return new Op(new Filter.Expression(ExpressionType.EQ, new Key(key), new Value(value))); } @@ -125,4 +112,19 @@ public class FilterExpressionBuilder { return new Op(new Filter.Expression(ExpressionType.NOT, content.expression, null)); } + public record Op(Filter.Operand expression) { + + public Filter.Expression build() { + if (this.expression instanceof Filter.Group group) { + // Remove the top-level grouping. + return group.content(); + } + else if (this.expression instanceof Filter.Expression exp) { + return exp; + } + throw new RuntimeException("Invalid expression: " + this.expression); + } + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionConverter.java index 127f2dc92..463da47fd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; /** @@ -24,6 +25,6 @@ package org.springframework.ai.vectorstore.filter; */ public interface FilterExpressionConverter { - public String convertExpression(Filter.Expression expression); + String convertExpression(Filter.Expression expression); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java index c4fd1a9d6..7d5e332c2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import java.util.ArrayList; @@ -162,7 +163,7 @@ public class FilterExpressionTextParser { /** For testing only */ Map getCache() { - return cache; + return this.cache; } public static class FilterExpressionParseException extends RuntimeException { @@ -217,9 +218,7 @@ public class FilterExpressionTextParser { @Override public Filter.Operand visitConstantArray(FiltersParser.ConstantArrayContext ctx) { List list = new ArrayList<>(); - ctx.constant().forEach(constantCtx -> { - list.add(((Filter.Value) this.visit(constantCtx)).value()); - }); + ctx.constant().forEach(constantCtx -> list.add(((Filter.Value) this.visit(constantCtx)).value())); return new Filter.Value(list); } @@ -301,4 +300,4 @@ public class FilterExpressionTextParser { } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterHelper.java index 555d2ca87..ce2bebf91 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterHelper.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import java.util.ArrayList; @@ -29,10 +30,7 @@ import org.springframework.util.Assert; * * @author Christian Tzolov */ -public class FilterHelper { - - private FilterHelper() { - } +public final class FilterHelper { private final static Map TYPE_NEGATION_MAP = Map.of(ExpressionType.AND, ExpressionType.OR, ExpressionType.OR, ExpressionType.AND, ExpressionType.EQ, ExpressionType.NE, @@ -40,6 +38,9 @@ public class FilterHelper { ExpressionType.LT, ExpressionType.LT, ExpressionType.GTE, ExpressionType.LTE, ExpressionType.GT, ExpressionType.IN, ExpressionType.NIN, ExpressionType.NIN, ExpressionType.IN); + private FilterHelper() { + } + /** * Transforms the input expression into a semantically equivalent one with negation * operators propagated thought the expression tree by following the negation rules: diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java index 65bca7a3a..136aedbab 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java @@ -1,23 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 -package org.springframework.ai.vectorstore.filter.antlr4; - -/* - * Copyright 2023-2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,6 +14,10 @@ package org.springframework.ai.vectorstore.filter.antlr4; * limitations under the License. */ +package org.springframework.ai.vectorstore.filter.antlr4; + +// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 + // ############################################################ // # NOTE: This is ANTLR4 auto-generated code. Do not modify! # // ############################################################ @@ -422,4 +408,4 @@ public class FiltersBaseListener implements FiltersListener { public void visitErrorNode(ErrorNode node) { } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java index 99b0de1b6..f8a5a2041 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java @@ -1,23 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 -package org.springframework.ai.vectorstore.filter.antlr4; - -/* - * Copyright 2023-2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,6 +14,10 @@ package org.springframework.ai.vectorstore.filter.antlr4; * limitations under the License. */ +package org.springframework.ai.vectorstore.filter.antlr4; + +// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 + // ############################################################ // # NOTE: This is ANTLR4 auto-generated code. Do not modify! # // ############################################################ @@ -244,4 +230,4 @@ public class FiltersBaseVisitor extends AbstractParseTreeVisitor implement return visitChildren(ctx); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersLexer.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersLexer.java index 87fa8abfd..cf4e31c59 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersLexer.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersLexer.java @@ -1,23 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 -package org.springframework.ai.vectorstore.filter.antlr4; - -/* - * Copyright 2023-2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,128 +14,40 @@ package org.springframework.ai.vectorstore.filter.antlr4; * limitations under the License. */ +package org.springframework.ai.vectorstore.filter.antlr4; + +// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 + // ############################################################ // # NOTE: This is ANTLR4 auto-generated code. Do not modify! # // ############################################################ -import org.antlr.v4.runtime.Lexer; import org.antlr.v4.runtime.CharStream; -import org.antlr.v4.runtime.*; -import org.antlr.v4.runtime.atn.*; +import org.antlr.v4.runtime.Lexer; +import org.antlr.v4.runtime.RuntimeMetaData; +import org.antlr.v4.runtime.Vocabulary; +import org.antlr.v4.runtime.VocabularyImpl; +import org.antlr.v4.runtime.atn.ATN; +import org.antlr.v4.runtime.atn.ATNDeserializer; +import org.antlr.v4.runtime.atn.LexerATNSimulator; +import org.antlr.v4.runtime.atn.PredictionContextCache; import org.antlr.v4.runtime.dfa.DFA; @SuppressWarnings({ "all", "warnings", "unchecked", "unused", "cast", "CheckReturnValue", "this-escape" }) public class FiltersLexer extends Lexer { - static { - RuntimeMetaData.checkVersion("4.13.1", RuntimeMetaData.VERSION); - } - - protected static final DFA[] _decisionToDFA; - - protected static final PredictionContextCache _sharedContextCache = new PredictionContextCache(); - public static final int WHERE = 1, DOT = 2, COMMA = 3, LEFT_SQUARE_BRACKETS = 4, RIGHT_SQUARE_BRACKETS = 5, LEFT_PARENTHESIS = 6, RIGHT_PARENTHESIS = 7, EQUALS = 8, MINUS = 9, PLUS = 10, GT = 11, GE = 12, LT = 13, LE = 14, NE = 15, AND = 16, OR = 17, IN = 18, NIN = 19, NOT = 20, BOOLEAN_VALUE = 21, QUOTED_STRING = 22, INTEGER_VALUE = 23, DECIMAL_VALUE = 24, IDENTIFIER = 25, WS = 26; - public static String[] channelNames = { "DEFAULT_TOKEN_CHANNEL", "HIDDEN" }; - - public static String[] modeNames = { "DEFAULT_MODE" }; - - private static String[] makeRuleNames() { - return new String[] { "WHERE", "DOT", "COMMA", "LEFT_SQUARE_BRACKETS", "RIGHT_SQUARE_BRACKETS", - "LEFT_PARENTHESIS", "RIGHT_PARENTHESIS", "EQUALS", "MINUS", "PLUS", "GT", "GE", "LT", "LE", "NE", "AND", - "OR", "IN", "NIN", "NOT", "BOOLEAN_VALUE", "QUOTED_STRING", "INTEGER_VALUE", "DECIMAL_VALUE", - "IDENTIFIER", "DECIMAL_DIGITS", "DIGIT", "LETTER", "WS" }; - } - public static final String[] ruleNames = makeRuleNames(); - private static String[] makeLiteralNames() { - return new String[] { null, null, "'.'", "','", "'['", "']'", "'('", "')'", "'=='", "'-'", "'+'", "'>'", "'>='", - "'<'", "'<='", "'!='" }; - } - - private static final String[] _LITERAL_NAMES = makeLiteralNames(); - - private static String[] makeSymbolicNames() { - return new String[] { null, "WHERE", "DOT", "COMMA", "LEFT_SQUARE_BRACKETS", "RIGHT_SQUARE_BRACKETS", - "LEFT_PARENTHESIS", "RIGHT_PARENTHESIS", "EQUALS", "MINUS", "PLUS", "GT", "GE", "LT", "LE", "NE", "AND", - "OR", "IN", "NIN", "NOT", "BOOLEAN_VALUE", "QUOTED_STRING", "INTEGER_VALUE", "DECIMAL_VALUE", - "IDENTIFIER", "WS" }; - } - - private static final String[] _SYMBOLIC_NAMES = makeSymbolicNames(); - - public static final Vocabulary VOCABULARY = new VocabularyImpl(_LITERAL_NAMES, _SYMBOLIC_NAMES); - /** * @deprecated Use {@link #VOCABULARY} instead. */ @Deprecated public static final String[] tokenNames; - static { - tokenNames = new String[_SYMBOLIC_NAMES.length]; - for (int i = 0; i < tokenNames.length; i++) { - tokenNames[i] = VOCABULARY.getLiteralName(i); - if (tokenNames[i] == null) { - tokenNames[i] = VOCABULARY.getSymbolicName(i); - } - - if (tokenNames[i] == null) { - tokenNames[i] = ""; - } - } - } - - @Override - @Deprecated - public String[] getTokenNames() { - return tokenNames; - } - - @Override - - public Vocabulary getVocabulary() { - return VOCABULARY; - } - - public FiltersLexer(CharStream input) { - super(input); - _interp = new LexerATNSimulator(this, _ATN, _decisionToDFA, _sharedContextCache); - } - - @Override - public String getGrammarFileName() { - return "Filters.g4"; - } - - @Override - public String[] getRuleNames() { - return ruleNames; - } - - @Override - public String getSerializedATN() { - return _serializedATN; - } - - @Override - public String[] getChannelNames() { - return channelNames; - } - - @Override - public String[] getModeNames() { - return modeNames; - } - - @Override - public ATN getATN() { - return _ATN; - } public static final String _serializedATN = "\u0004\u0000\u001a\u00e5\u0006\uffff\uffff\u0002\u0000\u0007\u0000\u0002" + "\u0001\u0007\u0001\u0002\u0002\u0007\u0002\u0002\u0003\u0007\u0003\u0002" @@ -307,6 +201,105 @@ public class FiltersLexer extends Lexer { + "\u00cf\u00d6\u00d8\u00e1\u0001\u0000\u0001\u0000"; public static final ATN _ATN = new ATNDeserializer().deserialize(_serializedATN.toCharArray()); + + protected static final DFA[] _decisionToDFA; + + protected static final PredictionContextCache _sharedContextCache = new PredictionContextCache(); + + private static final String[] _LITERAL_NAMES = makeLiteralNames(); + + private static final String[] _SYMBOLIC_NAMES = makeSymbolicNames(); + + public static final Vocabulary VOCABULARY = new VocabularyImpl(_LITERAL_NAMES, _SYMBOLIC_NAMES); + + public static String[] channelNames = { "DEFAULT_TOKEN_CHANNEL", "HIDDEN" }; + + public static String[] modeNames = { "DEFAULT_MODE" }; + + public FiltersLexer(CharStream input) { + super(input); + _interp = new LexerATNSimulator(this, _ATN, _decisionToDFA, _sharedContextCache); + } + + private static String[] makeRuleNames() { + return new String[] { "WHERE", "DOT", "COMMA", "LEFT_SQUARE_BRACKETS", "RIGHT_SQUARE_BRACKETS", + "LEFT_PARENTHESIS", "RIGHT_PARENTHESIS", "EQUALS", "MINUS", "PLUS", "GT", "GE", "LT", "LE", "NE", "AND", + "OR", "IN", "NIN", "NOT", "BOOLEAN_VALUE", "QUOTED_STRING", "INTEGER_VALUE", "DECIMAL_VALUE", + "IDENTIFIER", "DECIMAL_DIGITS", "DIGIT", "LETTER", "WS" }; + } + + private static String[] makeLiteralNames() { + return new String[] { null, null, "'.'", "','", "'['", "']'", "'('", "')'", "'=='", "'-'", "'+'", "'>'", "'>='", + "'<'", "'<='", "'!='" }; + } + + private static String[] makeSymbolicNames() { + return new String[] { null, "WHERE", "DOT", "COMMA", "LEFT_SQUARE_BRACKETS", "RIGHT_SQUARE_BRACKETS", + "LEFT_PARENTHESIS", "RIGHT_PARENTHESIS", "EQUALS", "MINUS", "PLUS", "GT", "GE", "LT", "LE", "NE", "AND", + "OR", "IN", "NIN", "NOT", "BOOLEAN_VALUE", "QUOTED_STRING", "INTEGER_VALUE", "DECIMAL_VALUE", + "IDENTIFIER", "WS" }; + } + + @Override + @Deprecated + public String[] getTokenNames() { + return tokenNames; + } + + @Override + + public Vocabulary getVocabulary() { + return VOCABULARY; + } + + @Override + public String getGrammarFileName() { + return "Filters.g4"; + } + + @Override + public String[] getRuleNames() { + return ruleNames; + } + + @Override + public String getSerializedATN() { + return _serializedATN; + } + + @Override + public String[] getChannelNames() { + return channelNames; + } + + @Override + public String[] getModeNames() { + return modeNames; + } + + @Override + public ATN getATN() { + return _ATN; + } + + static { + RuntimeMetaData.checkVersion("4.13.1", RuntimeMetaData.VERSION); + } + + static { + tokenNames = new String[_SYMBOLIC_NAMES.length]; + for (int i = 0; i < tokenNames.length; i++) { + tokenNames[i] = VOCABULARY.getLiteralName(i); + if (tokenNames[i] == null) { + tokenNames[i] = VOCABULARY.getSymbolicName(i); + } + + if (tokenNames[i] == null) { + tokenNames[i] = ""; + } + } + } + static { _decisionToDFA = new DFA[_ATN.getNumberOfDecisions()]; for (int i = 0; i < _ATN.getNumberOfDecisions(); i++) { @@ -314,4 +307,4 @@ public class FiltersLexer extends Lexer { } } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java index c16c841b7..8e49aeff6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java @@ -1,23 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 -package org.springframework.ai.vectorstore.filter.antlr4; - -/* - * Copyright 2023-2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,6 +14,10 @@ package org.springframework.ai.vectorstore.filter.antlr4; * limitations under the License. */ +package org.springframework.ai.vectorstore.filter.antlr4; + +// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 + // ############################################################ // # NOTE: This is ANTLR4 auto-generated code. Do not modify! # // ############################################################ @@ -246,4 +232,4 @@ public interface FiltersListener extends ParseTreeListener { */ void exitBooleanConstant(FiltersParser.BooleanConstantContext ctx); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java index ab66cff95..945a3a953 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java @@ -1,23 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 -package org.springframework.ai.vectorstore.filter.antlr4; - -/* - * Copyright 2023-2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,27 +14,39 @@ package org.springframework.ai.vectorstore.filter.antlr4; * limitations under the License. */ +package org.springframework.ai.vectorstore.filter.antlr4; + +// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 + // ############################################################ // # NOTE: This is ANTLR4 auto-generated code. Do not modify! # // ############################################################ -import org.antlr.v4.runtime.atn.*; -import org.antlr.v4.runtime.dfa.DFA; -import org.antlr.v4.runtime.*; -import org.antlr.v4.runtime.tree.*; import java.util.List; +import org.antlr.v4.runtime.FailedPredicateException; +import org.antlr.v4.runtime.NoViableAltException; +import org.antlr.v4.runtime.Parser; +import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.RecognitionException; +import org.antlr.v4.runtime.RuleContext; +import org.antlr.v4.runtime.RuntimeMetaData; +import org.antlr.v4.runtime.Token; +import org.antlr.v4.runtime.TokenStream; +import org.antlr.v4.runtime.Vocabulary; +import org.antlr.v4.runtime.VocabularyImpl; +import org.antlr.v4.runtime.atn.ATN; +import org.antlr.v4.runtime.atn.ATNDeserializer; +import org.antlr.v4.runtime.atn.ParserATNSimulator; +import org.antlr.v4.runtime.atn.PredictionContextCache; +import org.antlr.v4.runtime.dfa.DFA; +import org.antlr.v4.runtime.tree.ParseTreeListener; +import org.antlr.v4.runtime.tree.ParseTreeVisitor; +import org.antlr.v4.runtime.tree.TerminalNode; + @SuppressWarnings({ "all", "warnings", "unchecked", "unused", "cast", "CheckReturnValue" }) public class FiltersParser extends Parser { - static { - RuntimeMetaData.checkVersion("4.13.1", RuntimeMetaData.VERSION); - } - - protected static final DFA[] _decisionToDFA; - - protected static final PredictionContextCache _sharedContextCache = new PredictionContextCache(); - public static final int WHERE = 1, DOT = 2, COMMA = 3, LEFT_SQUARE_BRACKETS = 4, RIGHT_SQUARE_BRACKETS = 5, LEFT_PARENTHESIS = 6, RIGHT_PARENTHESIS = 7, EQUALS = 8, MINUS = 9, PLUS = 10, GT = 11, GE = 12, LT = 13, LE = 14, NE = 15, AND = 16, OR = 17, IN = 18, NIN = 19, NOT = 20, BOOLEAN_VALUE = 21, QUOTED_STRING = 22, @@ -61,35 +55,86 @@ public class FiltersParser extends Parser { public static final int RULE_where = 0, RULE_booleanExpression = 1, RULE_constantArray = 2, RULE_compare = 3, RULE_identifier = 4, RULE_constant = 5; - private static String[] makeRuleNames() { - return new String[] { "where", "booleanExpression", "constantArray", "compare", "identifier", "constant" }; - } - public static final String[] ruleNames = makeRuleNames(); - private static String[] makeLiteralNames() { - return new String[] { null, null, "'.'", "','", "'['", "']'", "'('", "')'", "'=='", "'-'", "'+'", "'>'", "'>='", - "'<'", "'<='", "'!='" }; - } - - private static final String[] _LITERAL_NAMES = makeLiteralNames(); - - private static String[] makeSymbolicNames() { - return new String[] { null, "WHERE", "DOT", "COMMA", "LEFT_SQUARE_BRACKETS", "RIGHT_SQUARE_BRACKETS", - "LEFT_PARENTHESIS", "RIGHT_PARENTHESIS", "EQUALS", "MINUS", "PLUS", "GT", "GE", "LT", "LE", "NE", "AND", - "OR", "IN", "NIN", "NOT", "BOOLEAN_VALUE", "QUOTED_STRING", "INTEGER_VALUE", "DECIMAL_VALUE", - "IDENTIFIER", "WS" }; - } - - private static final String[] _SYMBOLIC_NAMES = makeSymbolicNames(); - - public static final Vocabulary VOCABULARY = new VocabularyImpl(_LITERAL_NAMES, _SYMBOLIC_NAMES); - /** * @deprecated Use {@link #VOCABULARY} instead. */ @Deprecated public static final String[] tokenNames; + + public static final String _serializedATN = "\u0004\u0001\u001aY\u0002\u0000\u0007\u0000\u0002\u0001\u0007\u0001\u0002" + + "\u0002\u0007\u0002\u0002\u0003\u0007\u0003\u0002\u0004\u0007\u0004\u0002" + + "\u0005\u0007\u0005\u0001\u0000\u0001\u0000\u0001\u0000\u0001\u0000\u0001" + + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" + + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" + + "\u0001\u0003\u0001\u001e\b\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" + + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0003\u0001(\b" + + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" + + "\u0001\u0005\u00010\b\u0001\n\u0001\f\u00013\t\u0001\u0001\u0002\u0001" + + "\u0002\u0001\u0002\u0001\u0002\u0005\u00029\b\u0002\n\u0002\f\u0002<\t" + + "\u0002\u0001\u0002\u0001\u0002\u0001\u0003\u0001\u0003\u0001\u0004\u0001" + + "\u0004\u0001\u0004\u0001\u0004\u0001\u0004\u0003\u0004G\b\u0004\u0001" + + "\u0005\u0003\u0005J\b\u0005\u0001\u0005\u0001\u0005\u0003\u0005N\b\u0005" + + "\u0001\u0005\u0001\u0005\u0004\u0005R\b\u0005\u000b\u0005\f\u0005S\u0001" + + "\u0005\u0003\u0005W\b\u0005\u0001\u0005\u0000\u0001\u0002\u0006\u0000" + + "\u0002\u0004\u0006\b\n\u0000\u0002\u0002\u0000\b\b\u000b\u000f\u0001\u0000" + + "\t\nb\u0000\f\u0001\u0000\u0000\u0000\u0002\'\u0001\u0000\u0000\u0000" + + "\u00044\u0001\u0000\u0000\u0000\u0006?\u0001\u0000\u0000\u0000\bF\u0001" + + "\u0000\u0000\u0000\nV\u0001\u0000\u0000\u0000\f\r\u0005\u0001\u0000\u0000" + + "\r\u000e\u0003\u0002\u0001\u0000\u000e\u000f\u0005\u0000\u0000\u0001\u000f" + + "\u0001\u0001\u0000\u0000\u0000\u0010\u0011\u0006\u0001\uffff\uffff\u0000" + + "\u0011\u0012\u0003\b\u0004\u0000\u0012\u0013\u0003\u0006\u0003\u0000\u0013" + + "\u0014\u0003\n\u0005\u0000\u0014(\u0001\u0000\u0000\u0000\u0015\u0016" + + "\u0003\b\u0004\u0000\u0016\u0017\u0005\u0012\u0000\u0000\u0017\u0018\u0003" + + "\u0004\u0002\u0000\u0018(\u0001\u0000\u0000\u0000\u0019\u001d\u0003\b" + + "\u0004\u0000\u001a\u001b\u0005\u0014\u0000\u0000\u001b\u001e\u0005\u0012" + + "\u0000\u0000\u001c\u001e\u0005\u0013\u0000\u0000\u001d\u001a\u0001\u0000" + + "\u0000\u0000\u001d\u001c\u0001\u0000\u0000\u0000\u001e\u001f\u0001\u0000" + + "\u0000\u0000\u001f \u0003\u0004\u0002\u0000 (\u0001\u0000\u0000\u0000" + + "!\"\u0005\u0006\u0000\u0000\"#\u0003\u0002\u0001\u0000#$\u0005\u0007\u0000" + + "\u0000$(\u0001\u0000\u0000\u0000%&\u0005\u0014\u0000\u0000&(\u0003\u0002" + + "\u0001\u0001\'\u0010\u0001\u0000\u0000\u0000\'\u0015\u0001\u0000\u0000" + + "\u0000\'\u0019\u0001\u0000\u0000\u0000\'!\u0001\u0000\u0000\u0000\'%\u0001" + + "\u0000\u0000\u0000(1\u0001\u0000\u0000\u0000)*\n\u0004\u0000\u0000*+\u0005" + + "\u0010\u0000\u0000+0\u0003\u0002\u0001\u0005,-\n\u0003\u0000\u0000-.\u0005" + + "\u0011\u0000\u0000.0\u0003\u0002\u0001\u0004/)\u0001\u0000\u0000\u0000" + + "/,\u0001\u0000\u0000\u000003\u0001\u0000\u0000\u00001/\u0001\u0000\u0000" + + "\u000012\u0001\u0000\u0000\u00002\u0003\u0001\u0000\u0000\u000031\u0001" + + "\u0000\u0000\u000045\u0005\u0004\u0000\u00005:\u0003\n\u0005\u000067\u0005" + + "\u0003\u0000\u000079\u0003\n\u0005\u000086\u0001\u0000\u0000\u00009<\u0001" + + "\u0000\u0000\u0000:8\u0001\u0000\u0000\u0000:;\u0001\u0000\u0000\u0000" + + ";=\u0001\u0000\u0000\u0000<:\u0001\u0000\u0000\u0000=>\u0005\u0005\u0000" + + "\u0000>\u0005\u0001\u0000\u0000\u0000?@\u0007\u0000\u0000\u0000@\u0007" + + "\u0001\u0000\u0000\u0000AB\u0005\u0019\u0000\u0000BC\u0005\u0002\u0000" + + "\u0000CG\u0005\u0019\u0000\u0000DG\u0005\u0019\u0000\u0000EG\u0005\u0016" + + "\u0000\u0000FA\u0001\u0000\u0000\u0000FD\u0001\u0000\u0000\u0000FE\u0001" + + "\u0000\u0000\u0000G\t\u0001\u0000\u0000\u0000HJ\u0007\u0001\u0000\u0000" + + "IH\u0001\u0000\u0000\u0000IJ\u0001\u0000\u0000\u0000JK\u0001\u0000\u0000" + + "\u0000KW\u0005\u0017\u0000\u0000LN\u0007\u0001\u0000\u0000ML\u0001\u0000" + + "\u0000\u0000MN\u0001\u0000\u0000\u0000NO\u0001\u0000\u0000\u0000OW\u0005" + + "\u0018\u0000\u0000PR\u0005\u0016\u0000\u0000QP\u0001\u0000\u0000\u0000" + + "RS\u0001\u0000\u0000\u0000SQ\u0001\u0000\u0000\u0000ST\u0001\u0000\u0000" + + "\u0000TW\u0001\u0000\u0000\u0000UW\u0005\u0015\u0000\u0000VI\u0001\u0000" + + "\u0000\u0000VM\u0001\u0000\u0000\u0000VQ\u0001\u0000\u0000\u0000VU\u0001" + + "\u0000\u0000\u0000W\u000b\u0001\u0000\u0000\u0000\n\u001d\'/1:FIMSV"; + + public static final ATN _ATN = new ATNDeserializer().deserialize(_serializedATN.toCharArray()); + + protected static final DFA[] _decisionToDFA; + + protected static final PredictionContextCache _sharedContextCache = new PredictionContextCache(); + + private static final String[] _LITERAL_NAMES = makeLiteralNames(); + + private static final String[] _SYMBOLIC_NAMES = makeSymbolicNames(); + + public static final Vocabulary VOCABULARY = new VocabularyImpl(_LITERAL_NAMES, _SYMBOLIC_NAMES); + + static { + RuntimeMetaData.checkVersion("4.13.1", RuntimeMetaData.VERSION); + } + static { tokenNames = new String[_SYMBOLIC_NAMES.length]; for (int i = 0; i < tokenNames.length; i++) { @@ -104,6 +149,34 @@ public class FiltersParser extends Parser { } } + static { + _decisionToDFA = new DFA[_ATN.getNumberOfDecisions()]; + for (int i = 0; i < _ATN.getNumberOfDecisions(); i++) { + _decisionToDFA[i] = new DFA(_ATN.getDecisionState(i), i); + } + } + + public FiltersParser(TokenStream input) { + super(input); + _interp = new ParserATNSimulator(this, _ATN, _decisionToDFA, _sharedContextCache); + } + + private static String[] makeRuleNames() { + return new String[] { "where", "booleanExpression", "constantArray", "compare", "identifier", "constant" }; + } + + private static String[] makeLiteralNames() { + return new String[] { null, null, "'.'", "','", "'['", "']'", "'('", "')'", "'=='", "'-'", "'+'", "'>'", "'>='", + "'<'", "'<='", "'!='" }; + } + + private static String[] makeSymbolicNames() { + return new String[] { null, "WHERE", "DOT", "COMMA", "LEFT_SQUARE_BRACKETS", "RIGHT_SQUARE_BRACKETS", + "LEFT_PARENTHESIS", "RIGHT_PARENTHESIS", "EQUALS", "MINUS", "PLUS", "GT", "GE", "LT", "LE", "NE", "AND", + "OR", "IN", "NIN", "NOT", "BOOLEAN_VALUE", "QUOTED_STRING", "INTEGER_VALUE", "DECIMAL_VALUE", + "IDENTIFIER", "WS" }; + } + @Override @Deprecated public String[] getTokenNames() { @@ -136,57 +209,6 @@ public class FiltersParser extends Parser { return _ATN; } - public FiltersParser(TokenStream input) { - super(input); - _interp = new ParserATNSimulator(this, _ATN, _decisionToDFA, _sharedContextCache); - } - - @SuppressWarnings("CheckReturnValue") - public static class WhereContext extends ParserRuleContext { - - public TerminalNode WHERE() { - return getToken(FiltersParser.WHERE, 0); - } - - public BooleanExpressionContext booleanExpression() { - return getRuleContext(BooleanExpressionContext.class, 0); - } - - public TerminalNode EOF() { - return getToken(FiltersParser.EOF, 0); - } - - public WhereContext(ParserRuleContext parent, int invokingState) { - super(parent, invokingState); - } - - @Override - public int getRuleIndex() { - return RULE_where; - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterWhere(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitWhere(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitWhere(this); - else - return visitor.visitChildren(this); - } - - } - public final WhereContext where() throws RecognitionException { WhereContext _localctx = new WhereContext(_ctx, getState()); enterRule(_localctx, 0, RULE_where); @@ -212,330 +234,6 @@ public class FiltersParser extends Parser { return _localctx; } - @SuppressWarnings("CheckReturnValue") - public static class BooleanExpressionContext extends ParserRuleContext { - - public BooleanExpressionContext(ParserRuleContext parent, int invokingState) { - super(parent, invokingState); - } - - @Override - public int getRuleIndex() { - return RULE_booleanExpression; - } - - public BooleanExpressionContext() { - } - - public void copyFrom(BooleanExpressionContext ctx) { - super.copyFrom(ctx); - } - - } - - @SuppressWarnings("CheckReturnValue") - public static class NinExpressionContext extends BooleanExpressionContext { - - public IdentifierContext identifier() { - return getRuleContext(IdentifierContext.class, 0); - } - - public ConstantArrayContext constantArray() { - return getRuleContext(ConstantArrayContext.class, 0); - } - - public TerminalNode NOT() { - return getToken(FiltersParser.NOT, 0); - } - - public TerminalNode IN() { - return getToken(FiltersParser.IN, 0); - } - - public TerminalNode NIN() { - return getToken(FiltersParser.NIN, 0); - } - - public NinExpressionContext(BooleanExpressionContext ctx) { - copyFrom(ctx); - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterNinExpression(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitNinExpression(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitNinExpression(this); - else - return visitor.visitChildren(this); - } - - } - - @SuppressWarnings("CheckReturnValue") - public static class AndExpressionContext extends BooleanExpressionContext { - - public BooleanExpressionContext left; - - public Token operator; - - public BooleanExpressionContext right; - - public List booleanExpression() { - return getRuleContexts(BooleanExpressionContext.class); - } - - public BooleanExpressionContext booleanExpression(int i) { - return getRuleContext(BooleanExpressionContext.class, i); - } - - public TerminalNode AND() { - return getToken(FiltersParser.AND, 0); - } - - public AndExpressionContext(BooleanExpressionContext ctx) { - copyFrom(ctx); - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterAndExpression(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitAndExpression(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitAndExpression(this); - else - return visitor.visitChildren(this); - } - - } - - @SuppressWarnings("CheckReturnValue") - public static class InExpressionContext extends BooleanExpressionContext { - - public IdentifierContext identifier() { - return getRuleContext(IdentifierContext.class, 0); - } - - public TerminalNode IN() { - return getToken(FiltersParser.IN, 0); - } - - public ConstantArrayContext constantArray() { - return getRuleContext(ConstantArrayContext.class, 0); - } - - public InExpressionContext(BooleanExpressionContext ctx) { - copyFrom(ctx); - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterInExpression(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitInExpression(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitInExpression(this); - else - return visitor.visitChildren(this); - } - - } - - @SuppressWarnings("CheckReturnValue") - public static class NotExpressionContext extends BooleanExpressionContext { - - public TerminalNode NOT() { - return getToken(FiltersParser.NOT, 0); - } - - public BooleanExpressionContext booleanExpression() { - return getRuleContext(BooleanExpressionContext.class, 0); - } - - public NotExpressionContext(BooleanExpressionContext ctx) { - copyFrom(ctx); - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterNotExpression(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitNotExpression(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitNotExpression(this); - else - return visitor.visitChildren(this); - } - - } - - @SuppressWarnings("CheckReturnValue") - public static class CompareExpressionContext extends BooleanExpressionContext { - - public IdentifierContext identifier() { - return getRuleContext(IdentifierContext.class, 0); - } - - public CompareContext compare() { - return getRuleContext(CompareContext.class, 0); - } - - public ConstantContext constant() { - return getRuleContext(ConstantContext.class, 0); - } - - public CompareExpressionContext(BooleanExpressionContext ctx) { - copyFrom(ctx); - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterCompareExpression(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitCompareExpression(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitCompareExpression(this); - else - return visitor.visitChildren(this); - } - - } - - @SuppressWarnings("CheckReturnValue") - public static class OrExpressionContext extends BooleanExpressionContext { - - public BooleanExpressionContext left; - - public Token operator; - - public BooleanExpressionContext right; - - public List booleanExpression() { - return getRuleContexts(BooleanExpressionContext.class); - } - - public BooleanExpressionContext booleanExpression(int i) { - return getRuleContext(BooleanExpressionContext.class, i); - } - - public TerminalNode OR() { - return getToken(FiltersParser.OR, 0); - } - - public OrExpressionContext(BooleanExpressionContext ctx) { - copyFrom(ctx); - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterOrExpression(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitOrExpression(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitOrExpression(this); - else - return visitor.visitChildren(this); - } - - } - - @SuppressWarnings("CheckReturnValue") - public static class GroupExpressionContext extends BooleanExpressionContext { - - public TerminalNode LEFT_PARENTHESIS() { - return getToken(FiltersParser.LEFT_PARENTHESIS, 0); - } - - public BooleanExpressionContext booleanExpression() { - return getRuleContext(BooleanExpressionContext.class, 0); - } - - public TerminalNode RIGHT_PARENTHESIS() { - return getToken(FiltersParser.RIGHT_PARENTHESIS, 0); - } - - public GroupExpressionContext(BooleanExpressionContext ctx) { - copyFrom(ctx); - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterGroupExpression(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitGroupExpression(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitGroupExpression(this); - else - return visitor.visitChildren(this); - } - - } - public final BooleanExpressionContext booleanExpression() throws RecognitionException { return booleanExpression(0); } @@ -636,8 +334,9 @@ public class FiltersParser extends Parser { _alt = getInterpreter().adaptivePredict(_input, 3, _ctx); while (_alt != 2 && _alt != org.antlr.v4.runtime.atn.ATN.INVALID_ALT_NUMBER) { if (_alt == 1) { - if (_parseListeners != null) + if (_parseListeners != null) { triggerExitRuleEvent(); + } _prevctx = _localctx; { setState(47); @@ -649,8 +348,9 @@ public class FiltersParser extends Parser { ((AndExpressionContext) _localctx).left = _prevctx; pushNewRecursionContext(_localctx, _startState, RULE_booleanExpression); setState(41); - if (!(precpred(_ctx, 4))) + if (!(precpred(_ctx, 4))) { throw new FailedPredicateException(this, "precpred(_ctx, 4)"); + } setState(42); ((AndExpressionContext) _localctx).operator = match(AND); setState(43); @@ -663,8 +363,9 @@ public class FiltersParser extends Parser { ((OrExpressionContext) _localctx).left = _prevctx; pushNewRecursionContext(_localctx, _startState, RULE_booleanExpression); setState(44); - if (!(precpred(_ctx, 3))) + if (!(precpred(_ctx, 3))) { throw new FailedPredicateException(this, "precpred(_ctx, 3)"); + } setState(45); ((OrExpressionContext) _localctx).operator = match(OR); setState(46); @@ -691,64 +392,6 @@ public class FiltersParser extends Parser { return _localctx; } - @SuppressWarnings("CheckReturnValue") - public static class ConstantArrayContext extends ParserRuleContext { - - public TerminalNode LEFT_SQUARE_BRACKETS() { - return getToken(FiltersParser.LEFT_SQUARE_BRACKETS, 0); - } - - public List constant() { - return getRuleContexts(ConstantContext.class); - } - - public ConstantContext constant(int i) { - return getRuleContext(ConstantContext.class, i); - } - - public TerminalNode RIGHT_SQUARE_BRACKETS() { - return getToken(FiltersParser.RIGHT_SQUARE_BRACKETS, 0); - } - - public List COMMA() { - return getTokens(FiltersParser.COMMA); - } - - public TerminalNode COMMA(int i) { - return getToken(FiltersParser.COMMA, i); - } - - public ConstantArrayContext(ParserRuleContext parent, int invokingState) { - super(parent, invokingState); - } - - @Override - public int getRuleIndex() { - return RULE_constantArray; - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterConstantArray(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitConstantArray(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitConstantArray(this); - else - return visitor.visitChildren(this); - } - - } - public final ConstantArrayContext constantArray() throws RecognitionException { ConstantArrayContext _localctx = new ConstantArrayContext(_ctx, getState()); enterRule(_localctx, 4, RULE_constantArray); @@ -791,64 +434,6 @@ public class FiltersParser extends Parser { return _localctx; } - @SuppressWarnings("CheckReturnValue") - public static class CompareContext extends ParserRuleContext { - - public TerminalNode EQUALS() { - return getToken(FiltersParser.EQUALS, 0); - } - - public TerminalNode GT() { - return getToken(FiltersParser.GT, 0); - } - - public TerminalNode GE() { - return getToken(FiltersParser.GE, 0); - } - - public TerminalNode LT() { - return getToken(FiltersParser.LT, 0); - } - - public TerminalNode LE() { - return getToken(FiltersParser.LE, 0); - } - - public TerminalNode NE() { - return getToken(FiltersParser.NE, 0); - } - - public CompareContext(ParserRuleContext parent, int invokingState) { - super(parent, invokingState); - } - - @Override - public int getRuleIndex() { - return RULE_compare; - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterCompare(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitCompare(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitCompare(this); - else - return visitor.visitChildren(this); - } - - } - public final CompareContext compare() throws RecognitionException { CompareContext _localctx = new CompareContext(_ctx, getState()); enterRule(_localctx, 6, RULE_compare); @@ -862,8 +447,9 @@ public class FiltersParser extends Parser { _errHandler.recoverInline(this); } else { - if (_input.LA(1) == Token.EOF) + if (_input.LA(1) == Token.EOF) { matchedEOF = true; + } _errHandler.reportMatch(this); consume(); } @@ -880,56 +466,6 @@ public class FiltersParser extends Parser { return _localctx; } - @SuppressWarnings("CheckReturnValue") - public static class IdentifierContext extends ParserRuleContext { - - public List IDENTIFIER() { - return getTokens(FiltersParser.IDENTIFIER); - } - - public TerminalNode IDENTIFIER(int i) { - return getToken(FiltersParser.IDENTIFIER, i); - } - - public TerminalNode DOT() { - return getToken(FiltersParser.DOT, 0); - } - - public TerminalNode QUOTED_STRING() { - return getToken(FiltersParser.QUOTED_STRING, 0); - } - - public IdentifierContext(ParserRuleContext parent, int invokingState) { - super(parent, invokingState); - } - - @Override - public int getRuleIndex() { - return RULE_identifier; - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterIdentifier(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitIdentifier(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitIdentifier(this); - else - return visitor.visitChildren(this); - } - - } - public final IdentifierContext identifier() throws RecognitionException { IdentifierContext _localctx = new IdentifierContext(_ctx, getState()); enterRule(_localctx, 8, RULE_identifier); @@ -972,179 +508,6 @@ public class FiltersParser extends Parser { return _localctx; } - @SuppressWarnings("CheckReturnValue") - public static class ConstantContext extends ParserRuleContext { - - public ConstantContext(ParserRuleContext parent, int invokingState) { - super(parent, invokingState); - } - - @Override - public int getRuleIndex() { - return RULE_constant; - } - - public ConstantContext() { - } - - public void copyFrom(ConstantContext ctx) { - super.copyFrom(ctx); - } - - } - - @SuppressWarnings("CheckReturnValue") - public static class DecimalConstantContext extends ConstantContext { - - public TerminalNode DECIMAL_VALUE() { - return getToken(FiltersParser.DECIMAL_VALUE, 0); - } - - public TerminalNode MINUS() { - return getToken(FiltersParser.MINUS, 0); - } - - public TerminalNode PLUS() { - return getToken(FiltersParser.PLUS, 0); - } - - public DecimalConstantContext(ConstantContext ctx) { - copyFrom(ctx); - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterDecimalConstant(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitDecimalConstant(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitDecimalConstant(this); - else - return visitor.visitChildren(this); - } - - } - - @SuppressWarnings("CheckReturnValue") - public static class TextConstantContext extends ConstantContext { - - public List QUOTED_STRING() { - return getTokens(FiltersParser.QUOTED_STRING); - } - - public TerminalNode QUOTED_STRING(int i) { - return getToken(FiltersParser.QUOTED_STRING, i); - } - - public TextConstantContext(ConstantContext ctx) { - copyFrom(ctx); - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterTextConstant(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitTextConstant(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitTextConstant(this); - else - return visitor.visitChildren(this); - } - - } - - @SuppressWarnings("CheckReturnValue") - public static class BooleanConstantContext extends ConstantContext { - - public TerminalNode BOOLEAN_VALUE() { - return getToken(FiltersParser.BOOLEAN_VALUE, 0); - } - - public BooleanConstantContext(ConstantContext ctx) { - copyFrom(ctx); - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterBooleanConstant(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitBooleanConstant(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitBooleanConstant(this); - else - return visitor.visitChildren(this); - } - - } - - @SuppressWarnings("CheckReturnValue") - public static class IntegerConstantContext extends ConstantContext { - - public TerminalNode INTEGER_VALUE() { - return getToken(FiltersParser.INTEGER_VALUE, 0); - } - - public TerminalNode MINUS() { - return getToken(FiltersParser.MINUS, 0); - } - - public TerminalNode PLUS() { - return getToken(FiltersParser.PLUS, 0); - } - - public IntegerConstantContext(ConstantContext ctx) { - copyFrom(ctx); - } - - @Override - public void enterRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).enterIntegerConstant(this); - } - - @Override - public void exitRule(ParseTreeListener listener) { - if (listener instanceof FiltersListener) - ((FiltersListener) listener).exitIntegerConstant(this); - } - - @Override - public T accept(ParseTreeVisitor visitor) { - if (visitor instanceof FiltersVisitor) - return ((FiltersVisitor) visitor).visitIntegerConstant(this); - else - return visitor.visitChildren(this); - } - - } - public final ConstantContext constant() throws RecognitionException { ConstantContext _localctx = new ConstantContext(_ctx, getState()); enterRule(_localctx, 10, RULE_constant); @@ -1168,8 +531,9 @@ public class FiltersParser extends Parser { _errHandler.recoverInline(this); } else { - if (_input.LA(1) == Token.EOF) + if (_input.LA(1) == Token.EOF) { matchedEOF = true; + } _errHandler.reportMatch(this); consume(); } @@ -1194,8 +558,9 @@ public class FiltersParser extends Parser { _errHandler.recoverInline(this); } else { - if (_input.LA(1) == Token.EOF) + if (_input.LA(1) == Token.EOF) { matchedEOF = true; + } _errHandler.reportMatch(this); consume(); } @@ -1269,68 +634,773 @@ public class FiltersParser extends Parser { return true; } - public static final String _serializedATN = "\u0004\u0001\u001aY\u0002\u0000\u0007\u0000\u0002\u0001\u0007\u0001\u0002" - + "\u0002\u0007\u0002\u0002\u0003\u0007\u0003\u0002\u0004\u0007\u0004\u0002" - + "\u0005\u0007\u0005\u0001\u0000\u0001\u0000\u0001\u0000\u0001\u0000\u0001" - + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" - + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" - + "\u0001\u0003\u0001\u001e\b\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" - + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0003\u0001(\b" - + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" - + "\u0001\u0005\u00010\b\u0001\n\u0001\f\u00013\t\u0001\u0001\u0002\u0001" - + "\u0002\u0001\u0002\u0001\u0002\u0005\u00029\b\u0002\n\u0002\f\u0002<\t" - + "\u0002\u0001\u0002\u0001\u0002\u0001\u0003\u0001\u0003\u0001\u0004\u0001" - + "\u0004\u0001\u0004\u0001\u0004\u0001\u0004\u0003\u0004G\b\u0004\u0001" - + "\u0005\u0003\u0005J\b\u0005\u0001\u0005\u0001\u0005\u0003\u0005N\b\u0005" - + "\u0001\u0005\u0001\u0005\u0004\u0005R\b\u0005\u000b\u0005\f\u0005S\u0001" - + "\u0005\u0003\u0005W\b\u0005\u0001\u0005\u0000\u0001\u0002\u0006\u0000" - + "\u0002\u0004\u0006\b\n\u0000\u0002\u0002\u0000\b\b\u000b\u000f\u0001\u0000" - + "\t\nb\u0000\f\u0001\u0000\u0000\u0000\u0002\'\u0001\u0000\u0000\u0000" - + "\u00044\u0001\u0000\u0000\u0000\u0006?\u0001\u0000\u0000\u0000\bF\u0001" - + "\u0000\u0000\u0000\nV\u0001\u0000\u0000\u0000\f\r\u0005\u0001\u0000\u0000" - + "\r\u000e\u0003\u0002\u0001\u0000\u000e\u000f\u0005\u0000\u0000\u0001\u000f" - + "\u0001\u0001\u0000\u0000\u0000\u0010\u0011\u0006\u0001\uffff\uffff\u0000" - + "\u0011\u0012\u0003\b\u0004\u0000\u0012\u0013\u0003\u0006\u0003\u0000\u0013" - + "\u0014\u0003\n\u0005\u0000\u0014(\u0001\u0000\u0000\u0000\u0015\u0016" - + "\u0003\b\u0004\u0000\u0016\u0017\u0005\u0012\u0000\u0000\u0017\u0018\u0003" - + "\u0004\u0002\u0000\u0018(\u0001\u0000\u0000\u0000\u0019\u001d\u0003\b" - + "\u0004\u0000\u001a\u001b\u0005\u0014\u0000\u0000\u001b\u001e\u0005\u0012" - + "\u0000\u0000\u001c\u001e\u0005\u0013\u0000\u0000\u001d\u001a\u0001\u0000" - + "\u0000\u0000\u001d\u001c\u0001\u0000\u0000\u0000\u001e\u001f\u0001\u0000" - + "\u0000\u0000\u001f \u0003\u0004\u0002\u0000 (\u0001\u0000\u0000\u0000" - + "!\"\u0005\u0006\u0000\u0000\"#\u0003\u0002\u0001\u0000#$\u0005\u0007\u0000" - + "\u0000$(\u0001\u0000\u0000\u0000%&\u0005\u0014\u0000\u0000&(\u0003\u0002" - + "\u0001\u0001\'\u0010\u0001\u0000\u0000\u0000\'\u0015\u0001\u0000\u0000" - + "\u0000\'\u0019\u0001\u0000\u0000\u0000\'!\u0001\u0000\u0000\u0000\'%\u0001" - + "\u0000\u0000\u0000(1\u0001\u0000\u0000\u0000)*\n\u0004\u0000\u0000*+\u0005" - + "\u0010\u0000\u0000+0\u0003\u0002\u0001\u0005,-\n\u0003\u0000\u0000-.\u0005" - + "\u0011\u0000\u0000.0\u0003\u0002\u0001\u0004/)\u0001\u0000\u0000\u0000" - + "/,\u0001\u0000\u0000\u000003\u0001\u0000\u0000\u00001/\u0001\u0000\u0000" - + "\u000012\u0001\u0000\u0000\u00002\u0003\u0001\u0000\u0000\u000031\u0001" - + "\u0000\u0000\u000045\u0005\u0004\u0000\u00005:\u0003\n\u0005\u000067\u0005" - + "\u0003\u0000\u000079\u0003\n\u0005\u000086\u0001\u0000\u0000\u00009<\u0001" - + "\u0000\u0000\u0000:8\u0001\u0000\u0000\u0000:;\u0001\u0000\u0000\u0000" - + ";=\u0001\u0000\u0000\u0000<:\u0001\u0000\u0000\u0000=>\u0005\u0005\u0000" - + "\u0000>\u0005\u0001\u0000\u0000\u0000?@\u0007\u0000\u0000\u0000@\u0007" - + "\u0001\u0000\u0000\u0000AB\u0005\u0019\u0000\u0000BC\u0005\u0002\u0000" - + "\u0000CG\u0005\u0019\u0000\u0000DG\u0005\u0019\u0000\u0000EG\u0005\u0016" - + "\u0000\u0000FA\u0001\u0000\u0000\u0000FD\u0001\u0000\u0000\u0000FE\u0001" - + "\u0000\u0000\u0000G\t\u0001\u0000\u0000\u0000HJ\u0007\u0001\u0000\u0000" - + "IH\u0001\u0000\u0000\u0000IJ\u0001\u0000\u0000\u0000JK\u0001\u0000\u0000" - + "\u0000KW\u0005\u0017\u0000\u0000LN\u0007\u0001\u0000\u0000ML\u0001\u0000" - + "\u0000\u0000MN\u0001\u0000\u0000\u0000NO\u0001\u0000\u0000\u0000OW\u0005" - + "\u0018\u0000\u0000PR\u0005\u0016\u0000\u0000QP\u0001\u0000\u0000\u0000" - + "RS\u0001\u0000\u0000\u0000SQ\u0001\u0000\u0000\u0000ST\u0001\u0000\u0000" - + "\u0000TW\u0001\u0000\u0000\u0000UW\u0005\u0015\u0000\u0000VI\u0001\u0000" - + "\u0000\u0000VM\u0001\u0000\u0000\u0000VQ\u0001\u0000\u0000\u0000VU\u0001" - + "\u0000\u0000\u0000W\u000b\u0001\u0000\u0000\u0000\n\u001d\'/1:FIMSV"; + @SuppressWarnings("CheckReturnValue") + public static class WhereContext extends ParserRuleContext { - public static final ATN _ATN = new ATNDeserializer().deserialize(_serializedATN.toCharArray()); - static { - _decisionToDFA = new DFA[_ATN.getNumberOfDecisions()]; - for (int i = 0; i < _ATN.getNumberOfDecisions(); i++) { - _decisionToDFA[i] = new DFA(_ATN.getDecisionState(i), i); + public WhereContext(ParserRuleContext parent, int invokingState) { + super(parent, invokingState); } + + public TerminalNode WHERE() { + return getToken(FiltersParser.WHERE, 0); + } + + public BooleanExpressionContext booleanExpression() { + return getRuleContext(BooleanExpressionContext.class, 0); + } + + public TerminalNode EOF() { + return getToken(FiltersParser.EOF, 0); + } + + @Override + public int getRuleIndex() { + return RULE_where; + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterWhere(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitWhere(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitWhere(this); + } + else { + return visitor.visitChildren(this); + } + } + } -} \ No newline at end of file + @SuppressWarnings("CheckReturnValue") + public static class BooleanExpressionContext extends ParserRuleContext { + + public BooleanExpressionContext(ParserRuleContext parent, int invokingState) { + super(parent, invokingState); + } + + public BooleanExpressionContext() { + } + + @Override + public int getRuleIndex() { + return RULE_booleanExpression; + } + + public void copyFrom(BooleanExpressionContext ctx) { + super.copyFrom(ctx); + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class NinExpressionContext extends BooleanExpressionContext { + + public NinExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); + } + + public IdentifierContext identifier() { + return getRuleContext(IdentifierContext.class, 0); + } + + public ConstantArrayContext constantArray() { + return getRuleContext(ConstantArrayContext.class, 0); + } + + public TerminalNode NOT() { + return getToken(FiltersParser.NOT, 0); + } + + public TerminalNode IN() { + return getToken(FiltersParser.IN, 0); + } + + public TerminalNode NIN() { + return getToken(FiltersParser.NIN, 0); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterNinExpression(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitNinExpression(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitNinExpression(this); + } + else { + return visitor.visitChildren(this); + } + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class AndExpressionContext extends BooleanExpressionContext { + + public BooleanExpressionContext left; + + public Token operator; + + public BooleanExpressionContext right; + + public AndExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); + } + + public List booleanExpression() { + return getRuleContexts(BooleanExpressionContext.class); + } + + public BooleanExpressionContext booleanExpression(int i) { + return getRuleContext(BooleanExpressionContext.class, i); + } + + public TerminalNode AND() { + return getToken(FiltersParser.AND, 0); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterAndExpression(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitAndExpression(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitAndExpression(this); + } + else { + return visitor.visitChildren(this); + } + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class InExpressionContext extends BooleanExpressionContext { + + public InExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); + } + + public IdentifierContext identifier() { + return getRuleContext(IdentifierContext.class, 0); + } + + public TerminalNode IN() { + return getToken(FiltersParser.IN, 0); + } + + public ConstantArrayContext constantArray() { + return getRuleContext(ConstantArrayContext.class, 0); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterInExpression(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitInExpression(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitInExpression(this); + } + else { + return visitor.visitChildren(this); + } + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class NotExpressionContext extends BooleanExpressionContext { + + public NotExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); + } + + public TerminalNode NOT() { + return getToken(FiltersParser.NOT, 0); + } + + public BooleanExpressionContext booleanExpression() { + return getRuleContext(BooleanExpressionContext.class, 0); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterNotExpression(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitNotExpression(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitNotExpression(this); + } + else { + return visitor.visitChildren(this); + } + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class CompareExpressionContext extends BooleanExpressionContext { + + public CompareExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); + } + + public IdentifierContext identifier() { + return getRuleContext(IdentifierContext.class, 0); + } + + public CompareContext compare() { + return getRuleContext(CompareContext.class, 0); + } + + public ConstantContext constant() { + return getRuleContext(ConstantContext.class, 0); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterCompareExpression(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitCompareExpression(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitCompareExpression(this); + } + else { + return visitor.visitChildren(this); + } + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class OrExpressionContext extends BooleanExpressionContext { + + public BooleanExpressionContext left; + + public Token operator; + + public BooleanExpressionContext right; + + public OrExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); + } + + public List booleanExpression() { + return getRuleContexts(BooleanExpressionContext.class); + } + + public BooleanExpressionContext booleanExpression(int i) { + return getRuleContext(BooleanExpressionContext.class, i); + } + + public TerminalNode OR() { + return getToken(FiltersParser.OR, 0); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterOrExpression(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitOrExpression(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitOrExpression(this); + } + else { + return visitor.visitChildren(this); + } + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class GroupExpressionContext extends BooleanExpressionContext { + + public GroupExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); + } + + public TerminalNode LEFT_PARENTHESIS() { + return getToken(FiltersParser.LEFT_PARENTHESIS, 0); + } + + public BooleanExpressionContext booleanExpression() { + return getRuleContext(BooleanExpressionContext.class, 0); + } + + public TerminalNode RIGHT_PARENTHESIS() { + return getToken(FiltersParser.RIGHT_PARENTHESIS, 0); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterGroupExpression(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitGroupExpression(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitGroupExpression(this); + } + else { + return visitor.visitChildren(this); + } + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class ConstantArrayContext extends ParserRuleContext { + + public ConstantArrayContext(ParserRuleContext parent, int invokingState) { + super(parent, invokingState); + } + + public TerminalNode LEFT_SQUARE_BRACKETS() { + return getToken(FiltersParser.LEFT_SQUARE_BRACKETS, 0); + } + + public List constant() { + return getRuleContexts(ConstantContext.class); + } + + public ConstantContext constant(int i) { + return getRuleContext(ConstantContext.class, i); + } + + public TerminalNode RIGHT_SQUARE_BRACKETS() { + return getToken(FiltersParser.RIGHT_SQUARE_BRACKETS, 0); + } + + public List COMMA() { + return getTokens(FiltersParser.COMMA); + } + + public TerminalNode COMMA(int i) { + return getToken(FiltersParser.COMMA, i); + } + + @Override + public int getRuleIndex() { + return RULE_constantArray; + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterConstantArray(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitConstantArray(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitConstantArray(this); + } + else { + return visitor.visitChildren(this); + } + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class CompareContext extends ParserRuleContext { + + public CompareContext(ParserRuleContext parent, int invokingState) { + super(parent, invokingState); + } + + public TerminalNode EQUALS() { + return getToken(FiltersParser.EQUALS, 0); + } + + public TerminalNode GT() { + return getToken(FiltersParser.GT, 0); + } + + public TerminalNode GE() { + return getToken(FiltersParser.GE, 0); + } + + public TerminalNode LT() { + return getToken(FiltersParser.LT, 0); + } + + public TerminalNode LE() { + return getToken(FiltersParser.LE, 0); + } + + public TerminalNode NE() { + return getToken(FiltersParser.NE, 0); + } + + @Override + public int getRuleIndex() { + return RULE_compare; + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterCompare(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitCompare(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitCompare(this); + } + else { + return visitor.visitChildren(this); + } + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class IdentifierContext extends ParserRuleContext { + + public IdentifierContext(ParserRuleContext parent, int invokingState) { + super(parent, invokingState); + } + + public List IDENTIFIER() { + return getTokens(FiltersParser.IDENTIFIER); + } + + public TerminalNode IDENTIFIER(int i) { + return getToken(FiltersParser.IDENTIFIER, i); + } + + public TerminalNode DOT() { + return getToken(FiltersParser.DOT, 0); + } + + public TerminalNode QUOTED_STRING() { + return getToken(FiltersParser.QUOTED_STRING, 0); + } + + @Override + public int getRuleIndex() { + return RULE_identifier; + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterIdentifier(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitIdentifier(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitIdentifier(this); + } + else { + return visitor.visitChildren(this); + } + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class ConstantContext extends ParserRuleContext { + + public ConstantContext(ParserRuleContext parent, int invokingState) { + super(parent, invokingState); + } + + public ConstantContext() { + } + + @Override + public int getRuleIndex() { + return RULE_constant; + } + + public void copyFrom(ConstantContext ctx) { + super.copyFrom(ctx); + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class DecimalConstantContext extends ConstantContext { + + public DecimalConstantContext(ConstantContext ctx) { + copyFrom(ctx); + } + + public TerminalNode DECIMAL_VALUE() { + return getToken(FiltersParser.DECIMAL_VALUE, 0); + } + + public TerminalNode MINUS() { + return getToken(FiltersParser.MINUS, 0); + } + + public TerminalNode PLUS() { + return getToken(FiltersParser.PLUS, 0); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterDecimalConstant(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitDecimalConstant(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitDecimalConstant(this); + } + else { + return visitor.visitChildren(this); + } + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class TextConstantContext extends ConstantContext { + + public TextConstantContext(ConstantContext ctx) { + copyFrom(ctx); + } + + public List QUOTED_STRING() { + return getTokens(FiltersParser.QUOTED_STRING); + } + + public TerminalNode QUOTED_STRING(int i) { + return getToken(FiltersParser.QUOTED_STRING, i); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterTextConstant(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitTextConstant(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitTextConstant(this); + } + else { + return visitor.visitChildren(this); + } + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class BooleanConstantContext extends ConstantContext { + + public BooleanConstantContext(ConstantContext ctx) { + copyFrom(ctx); + } + + public TerminalNode BOOLEAN_VALUE() { + return getToken(FiltersParser.BOOLEAN_VALUE, 0); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterBooleanConstant(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitBooleanConstant(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitBooleanConstant(this); + } + else { + return visitor.visitChildren(this); + } + } + + } + + @SuppressWarnings("CheckReturnValue") + public static class IntegerConstantContext extends ConstantContext { + + public IntegerConstantContext(ConstantContext ctx) { + copyFrom(ctx); + } + + public TerminalNode INTEGER_VALUE() { + return getToken(FiltersParser.INTEGER_VALUE, 0); + } + + public TerminalNode MINUS() { + return getToken(FiltersParser.MINUS, 0); + } + + public TerminalNode PLUS() { + return getToken(FiltersParser.PLUS, 0); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).enterIntegerConstant(this); + } + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) { + ((FiltersListener) listener).exitIntegerConstant(this); + } + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) { + return ((FiltersVisitor) visitor).visitIntegerConstant(this); + } + else { + return visitor.visitChildren(this); + } + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java index 3f099b182..887159c2b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java @@ -1,23 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 -package org.springframework.ai.vectorstore.filter.antlr4; - -/* - * Copyright 2023-2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,12 +14,16 @@ package org.springframework.ai.vectorstore.filter.antlr4; * limitations under the License. */ +package org.springframework.ai.vectorstore.filter.antlr4; + +// Generated from org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 by ANTLR 4.13.1 + +import org.antlr.v4.runtime.tree.ParseTreeVisitor; + // ############################################################ // # NOTE: This is ANTLR4 auto-generated code. Do not modify! # // ############################################################ -import org.antlr.v4.runtime.tree.ParseTreeVisitor; - /** * This interface defines a complete generic visitor for a parse tree produced by * {@link FiltersParser}. @@ -163,4 +149,4 @@ public interface FiltersVisitor extends ParseTreeVisitor { */ T visitBooleanConstant(FiltersParser.BooleanConstantContext ctx); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java index b3fbeda67..808e790e8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter.converter; import java.util.List; import org.springframework.ai.vectorstore.filter.Filter; -import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; -import org.springframework.ai.vectorstore.filter.FilterHelper; import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; import org.springframework.ai.vectorstore.filter.Filter.Group; import org.springframework.ai.vectorstore.filter.Filter.Operand; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; +import org.springframework.ai.vectorstore.filter.FilterHelper; /** * AbstractFilterExpressionConverter is an abstract class that implements the diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverter.java index 4f8c6c061..64877fc24 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter.converter; import org.springframework.ai.vectorstore.filter.Filter.Expression; @@ -60,4 +61,4 @@ public class PineconeFilterExpressionConverter extends AbstractFilterExpressionC context.append("\"" + identifier + "\": "); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PrintFilterExpressionConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PrintFilterExpressionConverter.java index b2e93fcf7..14d2d1216 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PrintFilterExpressionConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/PrintFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter.converter; import org.springframework.ai.vectorstore.filter.Filter.Expression; @@ -47,4 +48,4 @@ public class PrintFilterExpressionConverter extends AbstractFilterExpressionConv context.append(")"); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java index ee4e98a6b..025f8e600 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java @@ -1,30 +1,31 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore.observation; import java.util.List; import java.util.Optional; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.document.Document; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.lang.Nullable; -import io.micrometer.observation.ObservationRegistry; - /** * @author Christian Tzolov * @since 1.0.0 @@ -53,7 +54,7 @@ public abstract class AbstractObservationVectorStore implements VectorStore { VectorStoreObservationDocumentation.AI_VECTOR_STORE .observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - observationRegistry) + this.observationRegistry) .observe(() -> this.doAdd(documents)); } @@ -96,4 +97,4 @@ public abstract class AbstractObservationVectorStore implements VectorStore { public abstract VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName); -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConvention.java index cfddd211c..15700d357 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConvention.java @@ -1,29 +1,30 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore.observation; +import io.micrometer.common.KeyValue; +import io.micrometer.common.KeyValues; + import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.LowCardinalityKeyNames; import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; -import io.micrometer.common.KeyValue; -import io.micrometer.common.KeyValues; - /** * Default conventions to populate observations for vector store operations. * diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContentProcessor.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContentProcessor.java index cb834f586..8513f81d6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContentProcessor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContentProcessor.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.observation; +import java.util.List; + import org.springframework.ai.document.Document; import org.springframework.util.CollectionUtils; -import java.util.List; - /** * Utilities to process the query content in observations for vector store operations. * @@ -27,6 +28,9 @@ import java.util.List; */ public final class VectorStoreObservationContentProcessor { + private VectorStoreObservationContentProcessor() { + } + public static List documents(VectorStoreObservationContext context) { if (CollectionUtils.isEmpty(context.getQueryResponse())) { return List.of(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContext.java index d12dd55ad..07da0990d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContext.java @@ -1,29 +1,30 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore.observation; import java.util.List; +import io.micrometer.observation.Observation; + import org.springframework.ai.document.Document; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.lang.Nullable; import org.springframework.util.Assert; -import io.micrometer.observation.Observation; - /** * Context used to store metadata for vector store operations. * @@ -33,6 +34,121 @@ import io.micrometer.observation.Observation; */ public class VectorStoreObservationContext extends Observation.Context { + private final String databaseSystem; + + // COMMON + + private final String operationName; + + @Nullable + private String collectionName; + + @Nullable + private Integer dimensions; + + @Nullable + private String fieldName; + + @Nullable + private String namespace; + + @Nullable + private String similarityMetric; + + @Nullable + private SearchRequest queryRequest; + + // SEARCH + + @Nullable + private List queryResponse; + + public VectorStoreObservationContext(String databaseSystem, String operationName) { + Assert.hasText(databaseSystem, "databaseSystem cannot be null or empty"); + Assert.hasText(operationName, "operationName cannot be null or empty"); + this.databaseSystem = databaseSystem; + this.operationName = operationName; + } + + public static Builder builder(String databaseSystem, String operationName) { + return new Builder(databaseSystem, operationName); + } + + public static Builder builder(String databaseSystem, Operation operation) { + return builder(databaseSystem, operation.value); + } + + public String getDatabaseSystem() { + return this.databaseSystem; + } + + public String getOperationName() { + return this.operationName; + } + + @Nullable + public String getCollectionName() { + return this.collectionName; + } + + public void setCollectionName(@Nullable String collectionName) { + this.collectionName = collectionName; + } + + @Nullable + public Integer getDimensions() { + return this.dimensions; + } + + public void setDimensions(@Nullable Integer dimensions) { + this.dimensions = dimensions; + } + + @Nullable + public String getFieldName() { + return this.fieldName; + } + + public void setFieldName(@Nullable String fieldName) { + this.fieldName = fieldName; + } + + @Nullable + public String getNamespace() { + return this.namespace; + } + + public void setNamespace(@Nullable String namespace) { + this.namespace = namespace; + } + + @Nullable + public String getSimilarityMetric() { + return this.similarityMetric; + } + + public void setSimilarityMetric(@Nullable String similarityMetric) { + this.similarityMetric = similarityMetric; + } + + @Nullable + public SearchRequest getQueryRequest() { + return this.queryRequest; + } + + public void setQueryRequest(@Nullable SearchRequest queryRequest) { + this.queryRequest = queryRequest; + } + + @Nullable + public List getQueryResponse() { + return this.queryResponse; + } + + public void setQueryResponse(@Nullable List queryResponse) { + this.queryResponse = queryResponse; + } + public enum Operation { /** @@ -60,121 +176,6 @@ public class VectorStoreObservationContext extends Observation.Context { } - // COMMON - - private final String databaseSystem; - - private final String operationName; - - @Nullable - private String collectionName; - - @Nullable - private Integer dimensions; - - @Nullable - private String fieldName; - - @Nullable - private String namespace; - - @Nullable - private String similarityMetric; - - // SEARCH - - @Nullable - private SearchRequest queryRequest; - - @Nullable - private List queryResponse; - - public VectorStoreObservationContext(String databaseSystem, String operationName) { - Assert.hasText(databaseSystem, "databaseSystem cannot be null or empty"); - Assert.hasText(operationName, "operationName cannot be null or empty"); - this.databaseSystem = databaseSystem; - this.operationName = operationName; - } - - public String getDatabaseSystem() { - return this.databaseSystem; - } - - public String getOperationName() { - return this.operationName; - } - - @Nullable - public String getCollectionName() { - return collectionName; - } - - public void setCollectionName(@Nullable String collectionName) { - this.collectionName = collectionName; - } - - @Nullable - public Integer getDimensions() { - return dimensions; - } - - public void setDimensions(@Nullable Integer dimensions) { - this.dimensions = dimensions; - } - - @Nullable - public String getFieldName() { - return fieldName; - } - - public void setFieldName(@Nullable String fieldName) { - this.fieldName = fieldName; - } - - @Nullable - public String getNamespace() { - return namespace; - } - - public void setNamespace(@Nullable String namespace) { - this.namespace = namespace; - } - - @Nullable - public String getSimilarityMetric() { - return similarityMetric; - } - - public void setSimilarityMetric(@Nullable String similarityMetric) { - this.similarityMetric = similarityMetric; - } - - @Nullable - public SearchRequest getQueryRequest() { - return queryRequest; - } - - public void setQueryRequest(@Nullable SearchRequest queryRequest) { - this.queryRequest = queryRequest; - } - - @Nullable - public List getQueryResponse() { - return queryResponse; - } - - public void setQueryResponse(@Nullable List queryResponse) { - this.queryResponse = queryResponse; - } - - public static Builder builder(String databaseSystem, String operationName) { - return new Builder(databaseSystem, operationName); - } - - public static Builder builder(String databaseSystem, Operation operation) { - return builder(databaseSystem, operation.value); - } - public static class Builder { private final VectorStoreObservationContext context; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationConvention.java index 9bf80d838..38a64d377 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationConvention.java @@ -1,18 +1,19 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore.observation; import io.micrometer.observation.Observation; @@ -30,4 +31,4 @@ public interface VectorStoreObservationConvention extends ObservationConvention< return context instanceof VectorStoreObservationContext; } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationDocumentation.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationDocumentation.java index f56ead4de..f351ca292 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationDocumentation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationDocumentation.java @@ -1,27 +1,28 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.vectorstore.observation; + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import org.springframework.ai.observation.conventions.VectorStoreObservationAttributes; +package org.springframework.ai.vectorstore.observation; import io.micrometer.common.docs.KeyName; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationConvention; import io.micrometer.observation.docs.ObservationDocumentation; +import org.springframework.ai.observation.conventions.VectorStoreObservationAttributes; + /** * Documented conventions for vector store observations. * @@ -85,7 +86,7 @@ public enum VectorStoreObservationDocumentation implements ObservationDocumentat public String asString() { return VectorStoreObservationAttributes.DB_SYSTEM.value(); } - }; + } } @@ -200,7 +201,7 @@ public enum VectorStoreObservationDocumentation implements ObservationDocumentat public String asString() { return VectorStoreObservationAttributes.DB_VECTOR_QUERY_TOP_K.value(); } - }; + } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilter.java index 4beab2b4f..a601acc3b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilter.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore.observation; -import org.springframework.ai.observation.tracing.TracingHelper; +package org.springframework.ai.vectorstore.observation; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationFilter; + +import org.springframework.ai.observation.tracing.TracingHelper; import org.springframework.util.CollectionUtils; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandler.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandler.java index 1e46710d7..9dbbefc8c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandler.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandler.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.observation; import io.micrometer.observation.Observation; @@ -21,6 +22,7 @@ import io.micrometer.tracing.handler.TracingObservationHandler; import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.common.Attributes; import io.opentelemetry.api.trace.Span; + import org.springframework.ai.observation.conventions.VectorStoreObservationAttributes; import org.springframework.ai.observation.conventions.VectorStoreObservationEventNames; import org.springframework.ai.observation.tracing.TracingHelper; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/package-info.java index a7e006093..0fd62c25b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,4 +19,4 @@ package org.springframework.ai.vectorstore.observation; import org.springframework.lang.NonNullApi; -import org.springframework.lang.NonNullFields; \ No newline at end of file +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/writer/FileDocumentWriter.java b/spring-ai-core/src/main/java/org/springframework/ai/writer/FileDocumentWriter.java index 023cfa6a1..8971ef22d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/writer/FileDocumentWriter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/writer/FileDocumentWriter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.writer; import java.io.FileWriter; diff --git a/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties b/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties index b6fb61c96..85a5447d3 100644 --- a/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties +++ b/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties @@ -1,3 +1,18 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # Map of embedding generative names and their dimensions text-embedding-ada-002=1536 text-similarity-ada-001=1024 diff --git a/spring-ai-core/src/test/java/org/springframework/ai/TestConfiguration.java b/spring-ai-core/src/test/java/org/springframework/ai/TestConfiguration.java index 203ac200f..582552998 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/TestConfiguration.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/TestConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai; import org.springframework.boot.SpringBootConfiguration; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/aot/AiRuntimeHintsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/aot/AiRuntimeHintsTests.java index 02addaecf..97df43159 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/aot/AiRuntimeHintsTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/aot/AiRuntimeHintsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.aot; import java.util.Set; @@ -28,25 +29,6 @@ import org.springframework.util.Assert; class AiRuntimeHintsTests { - @JsonInclude - static class TestApi { - - static class FooBar { - - } - - record Foo(@JsonProperty("name") String name) { - } - - @JsonInclude - enum Bar { - - A, B - - } - - } - @Test void discoverRelevantClasses() throws Exception { var classes = AiRuntimeHints.findJsonAnnotatedClassesInPackage(TestApi.class); @@ -58,4 +40,24 @@ class AiRuntimeHintsTests { Assert.state(classes.containsAll(included), "there should be all of the enumerated classes. "); } + @JsonInclude + static class TestApi { + + @JsonInclude + enum Bar { + + A, B + + } + + static class FooBar { + + } + + record Foo(@JsonProperty("name") String name) { + + } + + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/aot/KnuddelsRuntimeHintsTest.java b/spring-ai-core/src/test/java/org/springframework/ai/aot/KnuddelsRuntimeHintsTest.java index eb45821b8..409c2329e 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/aot/KnuddelsRuntimeHintsTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/aot/KnuddelsRuntimeHintsTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.aot; import org.junit.jupiter.api.Test; + import org.springframework.aot.hint.RuntimeHints; import static org.assertj.core.api.Assertions.assertThat; @@ -31,4 +33,4 @@ class KnuddelsRuntimeHintsTest { assertThat(runtimeHints).matches(resource().forResource("com/knuddels/jtokkit/cl100k_base.tiktoken")); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/aot/SpringAiCoreRuntimeHintsTest.java b/spring-ai-core/src/test/java/org/springframework/ai/aot/SpringAiCoreRuntimeHintsTest.java index 372d6f076..b379c9bb1 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/aot/SpringAiCoreRuntimeHintsTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/aot/SpringAiCoreRuntimeHintsTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.aot; import org.junit.jupiter.api.Test; @@ -38,4 +39,4 @@ class SpringAiCoreRuntimeHintsTest { assertThat(runtimeHints).matches(reflection().onMethod(FunctionCallback.class, "getName")); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java index 5ed9ccc8f..7ab6abb25 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.chat; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; + import org.junit.jupiter.api.Test; import org.springframework.ai.chat.prompt.ChatOptions; @@ -30,6 +30,8 @@ import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.model.function.FunctionCallingOptions; +import static org.assertj.core.api.Assertions.assertThat; + /** * Unit Tests for {@link Prompt}. * diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatModelTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatModelTests.java index f3be7deaf..27568b6be 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatModelTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatModelTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,28 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; + import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doAnswer; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.doCallRealMethod; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; - -import org.junit.jupiter.api.Test; - -import org.mockito.Mockito; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.Prompt; /** * Unit Tests for {@link ChatModel}. @@ -53,15 +53,15 @@ class ChatModelTests { ChatModel mockClient = Mockito.mock(ChatModel.class); AssistantMessage mockAssistantMessage = Mockito.mock(AssistantMessage.class); - when(mockAssistantMessage.getContent()).thenReturn(responseMessage); + given(mockAssistantMessage.getContent()).willReturn(responseMessage); // Create a mock Generation Generation generation = Mockito.mock(Generation.class); - when(generation.getOutput()).thenReturn(mockAssistantMessage); + given(generation.getOutput()).willReturn(mockAssistantMessage); // Create a mock ChatResponse with the mock Generation ChatResponse response = Mockito.mock(ChatResponse.class); - when(response.getResult()).thenReturn(generation); + given(response.getResult()).willReturn(generation); // Generation generation = spy(new Generation(responseMessage)); // ChatResponse response = spy(new @@ -69,16 +69,14 @@ class ChatModelTests { doCallRealMethod().when(mockClient).call(anyString()); - doAnswer(invocationOnMock -> { - + given(mockClient.call(any(Prompt.class))).willAnswer(invocationOnMock -> { Prompt prompt = invocationOnMock.getArgument(0); assertThat(prompt).isNotNull(); assertThat(prompt.getContents()).isEqualTo(userMessage); return response; - - }).when(mockClient).call(any(Prompt.class)); + }); assertThat(mockClient.call(userMessage)).isEqualTo(responseMessage); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java index 251dd184e..07d77ecb1 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,6 @@ package org.springframework.ai.chat.client; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; - import java.util.List; import java.util.stream.Collectors; @@ -28,6 +25,8 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.memory.InMemoryChatMemory; @@ -35,14 +34,14 @@ import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.chat.prompt.Prompt; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; /** * @author Christian Tzolov @@ -71,15 +70,15 @@ public class ChatClientAdvisorTests { .withKeyValue("system-fingerprint", "john doe"); ChatResponseMetadata chatResponseMetadata = builder.build(); - when(chatModel.call(promptCaptor.capture())) - .thenReturn( + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn( new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John"))), chatResponseMetadata)) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your name is John"))), + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your name is John"))), chatResponseMetadata)); ChatMemory chatMemory = new InMemoryChatMemory(); - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem("Default system text.") .defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory)) .build(); @@ -89,7 +88,7 @@ public class ChatClientAdvisorTests { String content = chatResponse.getResult().getOutput().getContent(); assertThat(content).isEqualTo("Hello John"); - Message systemMessage = promptCaptor.getValue().getInstructions().get(0); + Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualToIgnoringWhitespace(""" Default system text. @@ -101,14 +100,14 @@ public class ChatClientAdvisorTests { """); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); - Message userMessage = promptCaptor.getValue().getInstructions().get(1); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getContent()).isEqualToIgnoringWhitespace("my name is John"); content = chatClient.prompt().user("What is my name?").call().content(); assertThat(content).isEqualTo("Your name is John"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualToIgnoringWhitespace(""" Default system text. @@ -122,20 +121,20 @@ public class ChatClientAdvisorTests { """); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); - userMessage = promptCaptor.getValue().getInstructions().get(1); + userMessage = this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getContent()).isEqualToIgnoringWhitespace("What is my name?"); } @Test public void streamingPromptChatMemory() { - when(chatModel.stream(promptCaptor.capture())).thenReturn(Flux.generate( + given(this.chatModel.stream(this.promptCaptor.capture())).willReturn(Flux.generate( () -> new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John")))), (state, sink) -> { sink.next(state); sink.complete(); return state; })) - .thenReturn(Flux.generate( + .willReturn(Flux.generate( () -> new ChatResponse(List.of(new Generation(new AssistantMessage("Your name is John")))), (state, sink) -> { sink.next(state); @@ -145,7 +144,7 @@ public class ChatClientAdvisorTests { ChatMemory chatMemory = new InMemoryChatMemory(); - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem("Default system text.") .defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory)) .build(); @@ -154,7 +153,7 @@ public class ChatClientAdvisorTests { assertThat(content).isEqualTo("Hello John"); - Message systemMessage = promptCaptor.getValue().getInstructions().get(0); + Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualToIgnoringWhitespace(""" Default system text. @@ -166,14 +165,14 @@ public class ChatClientAdvisorTests { """); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); - Message userMessage = promptCaptor.getValue().getInstructions().get(1); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getContent()).isEqualToIgnoringWhitespace("my name is John"); content = join(chatClient.prompt().user("What is my name?").stream().content()); assertThat(content).isEqualTo("Your name is John"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualToIgnoringWhitespace(""" Default system text. @@ -187,7 +186,7 @@ public class ChatClientAdvisorTests { """); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); - userMessage = promptCaptor.getValue().getInstructions().get(1); + userMessage = this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getContent()).isEqualToIgnoringWhitespace("What is my name?"); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java index 2e40f9def..5295a213a 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ import org.springframework.ai.converter.MapOutputConverter; import org.springframework.core.ParameterizedTypeReference; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Christian Tzolov @@ -51,9 +51,6 @@ public class ChatClientResponseEntityTests { @Captor ArgumentCaptor promptCaptor; - record MyBean(String name, int age) { - } - @Test public void responseEntityTest() { @@ -63,9 +60,9 @@ public class ChatClientResponseEntityTests { {"name":"John", "age":30} """)), metadata); - when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse); + given(this.chatModel.call(this.promptCaptor.capture())).willReturn(chatResponse); - ResponseEntity responseEntity = ChatClient.builder(chatModel) + ResponseEntity responseEntity = ChatClient.builder(this.chatModel) .build() .prompt() .user("Tell me about John") @@ -77,7 +74,7 @@ public class ChatClientResponseEntityTests { assertThat(responseEntity.getEntity()).isEqualTo(new MyBean("John", 30)); - Message userMessage = promptCaptor.getValue().getInstructions().get(0); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getContent()).contains("Tell me about John"); } @@ -87,26 +84,27 @@ public class ChatClientResponseEntityTests { var chatResponse = new ChatResponse(List.of(new Generation(""" [ - {"name":"Max", "age":10}, - {"name":"Adi", "age":13} + {"name":"Max", "age":10}, + {"name":"Adi", "age":13} ] """))); - when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse); + given(this.chatModel.call(this.promptCaptor.capture())).willReturn(chatResponse); - ResponseEntity> responseEntity = ChatClient.builder(chatModel) + ResponseEntity> responseEntity = ChatClient.builder(this.chatModel) .build() .prompt() .user("Tell me about them") .call() .responseEntity(new ParameterizedTypeReference>() { + }); assertThat(responseEntity.getResponse()).isEqualTo(chatResponse); assertThat(responseEntity.getEntity().get(0)).isEqualTo(new MyBean("Max", 10)); assertThat(responseEntity.getEntity().get(1)).isEqualTo(new MyBean("Adi", 13)); - Message userMessage = promptCaptor.getValue().getInstructions().get(0); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getContent()).contains("Tell me about them"); } @@ -115,12 +113,12 @@ public class ChatClientResponseEntityTests { public void customSoCResponseEntityTest() { var chatResponse = new ChatResponse(List.of(new Generation(""" - {"name":"Max", "age":10}, + {"name":"Max", "age":10}, """))); - when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse); + given(this.chatModel.call(this.promptCaptor.capture())).willReturn(chatResponse); - ResponseEntity> responseEntity = ChatClient.builder(chatModel) + ResponseEntity> responseEntity = ChatClient.builder(this.chatModel) .build() .prompt() .user("Tell me about Max") @@ -131,9 +129,13 @@ public class ChatClientResponseEntityTests { assertThat(responseEntity.getEntity().get("name")).isEqualTo("Max"); assertThat(responseEntity.getEntity().get("age")).isEqualTo(10); - Message userMessage = promptCaptor.getValue().getInstructions().get(0); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getContent()).contains("Tell me about Max"); } + record MyBean(String name, int age) { + + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java index 7033cdd5a..3dc853b2f 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ import org.springframework.core.io.DefaultResourceLoader; import org.springframework.util.MimeTypeUtils; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Christian Tzolov @@ -54,6 +54,14 @@ import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) public class ChatClientTest { + static Function mockFunction = new Function() { + + @Override + public String apply(String s) { + return s; + } + }; + @Mock ChatModel chatModel; @@ -68,23 +76,23 @@ public class ChatClientTest { @Test public void defaultSystemText() { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - when(chatModel.stream(promptCaptor.capture())).thenReturn(Flux.generate( + given(this.chatModel.stream(this.promptCaptor.capture())).willReturn(Flux.generate( () -> new ChatResponse(List.of(new Generation(new AssistantMessage("response")))), (state, sink) -> { sink.next(state); sink.complete(); return state; })); - var chatClient = ChatClient.builder(chatModel).defaultSystem("Default system text").build(); + var chatClient = ChatClient.builder(this.chatModel).defaultSystem("Default system text").build(); var content = chatClient.prompt().call().content(); assertThat(content).isEqualTo("response"); - Message systemMessage = promptCaptor.getValue().getInstructions().get(0); + Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Default system text"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -92,7 +100,7 @@ public class ChatClientTest { assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Default system text"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -100,7 +108,7 @@ public class ChatClientTest { content = chatClient.prompt().system("Override default system text").call().content(); assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Override default system text"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -108,7 +116,7 @@ public class ChatClientTest { content = join(chatClient.prompt().system("Override default system text").stream().content()); assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Override default system text"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); } @@ -116,17 +124,17 @@ public class ChatClientTest { @Test public void defaultSystemTextLambda() { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - when(chatModel.stream(promptCaptor.capture())).thenReturn(Flux.generate( + given(this.chatModel.stream(this.promptCaptor.capture())).willReturn(Flux.generate( () -> new ChatResponse(List.of(new Generation(new AssistantMessage("response")))), (state, sink) -> { sink.next(state); sink.complete(); return state; })); - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem(s -> s.text("Default system text {param1}, {param2}") .param("param1", "value1") .param("param2", "value2")) @@ -136,7 +144,7 @@ public class ChatClientTest { assertThat(content).isEqualTo("response"); - Message systemMessage = promptCaptor.getValue().getInstructions().get(0); + Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Default system text value1, value2"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -145,7 +153,7 @@ public class ChatClientTest { assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Default system text value1, value2"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -153,7 +161,7 @@ public class ChatClientTest { content = chatClient.prompt().system(s -> s.param("param1", "value1New")).call().content(); assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Default system text value1New, value2"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -161,7 +169,7 @@ public class ChatClientTest { content = join(chatClient.prompt().system(s -> s.param("param1", "value1New")).stream().content()); assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Default system text value1New, value2"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -172,7 +180,7 @@ public class ChatClientTest { .content(); assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Override default system text value3"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -183,28 +191,21 @@ public class ChatClientTest { .content()); assertThat(content).isEqualTo("response"); - systemMessage = promptCaptor.getValue().getInstructions().get(0); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("Override default system text value3"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); } - static Function mockFunction = new Function() { - @Override - public String apply(String s) { - return s; - } - }; - @Test public void mutateDefaults() { PortableFunctionCallingOptions options = new FunctionCallingOptionsBuilder().build(); - when(chatModel.getDefaultOptions()).thenReturn(options); + given(this.chatModel.getDefaultOptions()).willReturn(options); - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - when(chatModel.stream(promptCaptor.capture())).thenReturn(Flux.generate( + given(this.chatModel.stream(this.promptCaptor.capture())).willReturn(Flux.generate( () -> new ChatResponse(List.of(new Generation(new AssistantMessage("response")))), (state, sink) -> { sink.next(state); sink.complete(); @@ -212,7 +213,7 @@ public class ChatClientTest { })); // @formatter:off - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem(s -> s.text("Default system text {param1}, {param2}") .param("param1", "value1") .param("param2", "value2")) @@ -230,7 +231,7 @@ public class ChatClientTest { assertThat(content).isEqualTo("response"); - Prompt prompt = promptCaptor.getValue(); + Prompt prompt = this.promptCaptor.getValue(); Message systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -252,7 +253,7 @@ public class ChatClientTest { assertThat(content).isEqualTo("response"); - prompt = promptCaptor.getValue(); + prompt = this.promptCaptor.getValue(); systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -282,7 +283,7 @@ public class ChatClientTest { assertThat(content).isEqualTo("response"); - prompt = promptCaptor.getValue(); + prompt = this.promptCaptor.getValue(); systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -304,7 +305,7 @@ public class ChatClientTest { assertThat(content).isEqualTo("response"); - prompt = promptCaptor.getValue(); + prompt = this.promptCaptor.getValue(); systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -327,19 +328,19 @@ public class ChatClientTest { public void mutatePrompt() { PortableFunctionCallingOptions options = new FunctionCallingOptionsBuilder().build(); - when(chatModel.getDefaultOptions()).thenReturn(options); + given(this.chatModel.getDefaultOptions()).willReturn(options); - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - when(chatModel.stream(promptCaptor.capture())).thenReturn(Flux.generate( + given(this.chatModel.stream(this.promptCaptor.capture())).willReturn(Flux.generate( () -> new ChatResponse(List.of(new Generation(new AssistantMessage("response")))), (state, sink) -> { sink.next(state); sink.complete(); return state; })); // @formatter:off - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem(s -> s.text("Default system text {param1}, {param2}") .param("param1", "value1") .param("param2", "value2")) @@ -364,7 +365,7 @@ public class ChatClientTest { assertThat(content).isEqualTo("response"); - Prompt prompt = promptCaptor.getValue(); + Prompt prompt = this.promptCaptor.getValue(); Message systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -395,7 +396,7 @@ public class ChatClientTest { assertThat(content).isEqualTo("response"); - prompt = promptCaptor.getValue(); + prompt = this.promptCaptor.getValue(); systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); @@ -416,16 +417,16 @@ public class ChatClientTest { @Test public void defaultUserText() { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - var chatClient = ChatClient.builder(chatModel).defaultUser("Default user text").build(); + var chatClient = ChatClient.builder(this.chatModel).defaultUser("Default user text").build(); var content = chatClient.prompt().call().content(); assertThat(content).isEqualTo("response"); - Message userMessage = promptCaptor.getValue().getInstructions().get(0); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getContent()).isEqualTo("Default user text"); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); @@ -433,50 +434,51 @@ public class ChatClientTest { content = chatClient.prompt().user("Override default user text").call().content(); assertThat(content).isEqualTo("response"); - userMessage = promptCaptor.getValue().getInstructions().get(0); + userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getContent()).isEqualTo("Override default user text"); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); } @Test public void simpleUserPromptAsString() { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - assertThat(ChatClient.builder(chatModel).build().prompt("User prompt").call().content()).isEqualTo("response"); + assertThat(ChatClient.builder(this.chatModel).build().prompt("User prompt").call().content()) + .isEqualTo("response"); - Message userMessage = promptCaptor.getValue().getInstructions().get(0); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getContent()).isEqualTo("User prompt"); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); } @Test public void simpleUserPrompt() { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - assertThat(ChatClient.builder(chatModel).build().prompt().user("User prompt").call().content()) + assertThat(ChatClient.builder(this.chatModel).build().prompt().user("User prompt").call().content()) .isEqualTo("response"); - Message userMessage = promptCaptor.getValue().getInstructions().get(0); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getContent()).isEqualTo("User prompt"); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); } @Test public void simpleUserPromptObject() throws MalformedURLException { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); var media = new Media(MimeTypeUtils.IMAGE_JPEG, new DefaultResourceLoader().getResource("classpath:/bikes.json")); UserMessage message = new UserMessage("User prompt", List.of(media)); Prompt prompt = new Prompt(message); - assertThat(ChatClient.builder(chatModel).build().prompt(prompt).call().content()).isEqualTo("response"); + assertThat(ChatClient.builder(this.chatModel).build().prompt(prompt).call().content()).isEqualTo("response"); - assertThat(promptCaptor.getValue().getInstructions()).hasSize(1); - Message userMessage = promptCaptor.getValue().getInstructions().get(0); + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(1); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getContent()).isEqualTo("User prompt"); assertThat(((UserMessage) userMessage).getMedia()).hasSize(1); @@ -484,32 +486,32 @@ public class ChatClientTest { @Test public void simpleSystemPrompt() throws MalformedURLException { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - String response = ChatClient.builder(chatModel).build().prompt().system("System prompt").call().content(); + String response = ChatClient.builder(this.chatModel).build().prompt().system("System prompt").call().content(); assertThat(response).isEqualTo("response"); - assertThat(promptCaptor.getValue().getInstructions()).hasSize(1); + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(1); - Message systemMessage = promptCaptor.getValue().getInstructions().get(0); + Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("System prompt"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); } @Test public void complexCall() throws MalformedURLException { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); var options = FunctionCallingOptions.builder().build(); - when(chatModel.getDefaultOptions()).thenReturn(options); + given(this.chatModel.getDefaultOptions()).willReturn(options); var url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); // @formatter:off - ChatClient client = ChatClient.builder(chatModel) + ChatClient client = ChatClient.builder(this.chatModel) .defaultSystem("System text") .defaultFunctions("function1") .build(); @@ -521,13 +523,13 @@ public class ChatClientTest { // @formatter:on assertThat(response).isEqualTo("response"); - assertThat(promptCaptor.getValue().getInstructions()).hasSize(2); + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(2); - Message systemMessage = promptCaptor.getValue().getInstructions().get(0); + Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("System text"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); - UserMessage userMessage = (UserMessage) promptCaptor.getValue().getInstructions().get(1); + UserMessage userMessage = (UserMessage) this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getContent()).isEqualTo("User text Rock"); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getMedia()).hasSize(1); @@ -535,7 +537,7 @@ public class ChatClientTest { assertThat(userMessage.getMedia().iterator().next().getData()) .isEqualTo("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); - FunctionCallingOptions runtieOptions = (FunctionCallingOptions) promptCaptor.getValue().getOptions(); + FunctionCallingOptions runtieOptions = (FunctionCallingOptions) this.promptCaptor.getValue().getOptions(); assertThat(runtieOptions.getFunctions()).containsExactly("function1"); assertThat(options.getFunctions()).isEmpty(); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java index 9a8abcfce..fd92eb230 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,6 @@ package org.springframework.ai.chat.client.advisor; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; -import static org.mockito.Mockito.verify; - import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -31,11 +27,13 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.messages.AssistantMessage; @@ -44,7 +42,9 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.verify; /** * @author Christian Tzolov @@ -58,76 +58,6 @@ public class AdvisorsTests { @Captor ArgumentCaptor promptCaptor; - public class MockAroundAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { - - public AdvisedRequest advisedRequest; - - public AdvisedResponse advisedResponse; - - public List aroundAdvisedResponses = new ArrayList<>(); - - private final String name; - - private final int order; - - public MockAroundAdvisor(String name, int order) { - this.name = name; - this.order = order; - } - - @Override - public String getName() { - return name; - } - - @Override - public int getOrder() { - return order; - } - - @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - - this.advisedRequest = advisedRequest.updateContext(context -> { - context.put("aroundCallBefore" + getName(), "AROUND_CALL_BEFORE " + getName()); - context.put("lastBefore", getName()); - return context; - }); - - AdvisedResponse advisedResponse = this.advisedResponse = chain.nextAroundCall(this.advisedRequest); - - this.advisedResponse = advisedResponse.updateContext(context -> { - context.put("aroundCallAfter" + name, "AROUND_CALL_AFTER " + name); - context.put("lastAfter", name); - return context; - }); - - return this.advisedResponse; - } - - @Override - public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { - - this.advisedRequest = advisedRequest.updateContext(context -> { - context.put("aroundStreamBefore" + name, "AROUND_STREAM_BEFORE " + name); - context.put("lastBefore", name); - return context; - }); - - Flux advisedResponseStream = chain.nextAroundStream(this.advisedRequest); - - return advisedResponseStream.map(advisedResponse -> { - return advisedResponse.updateContext(context -> { - context.put("aroundStreamAfter" + name, "AROUND_STREAM_AFTER " + name); - context.put("lastAfter", name); - return context; - }); - }).doOnNext(ar -> this.aroundAdvisedResponses.add(ar)); - - } - - } - @Test public void callAdvisorsContextPropagation() { @@ -136,10 +66,10 @@ public class AdvisorsTests { var mockAroundAdvisor1 = new MockAroundAdvisor("Advisor1", 0); var mockAroundAdvisor2 = new MockAroundAdvisor("Advisor2", 1); - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John"))))); - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem("Default system text.") .defaultAdvisors(mockAroundAdvisor1) .build(); @@ -164,7 +94,7 @@ public class AdvisorsTests { .containsEntry("lastBefore", "Advisor2") // inner .containsEntry("lastAfter", "Advisor1"); // outer - verify(chatModel).call(promptCaptor.capture()); + verify(this.chatModel).call(this.promptCaptor.capture()); } @Test @@ -173,11 +103,11 @@ public class AdvisorsTests { var mockAroundAdvisor1 = new MockAroundAdvisor("Advisor1", 0); var mockAroundAdvisor2 = new MockAroundAdvisor("Advisor2", 1); - when(chatModel.stream(promptCaptor.capture())) - .thenReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello")))), + given(this.chatModel.stream(this.promptCaptor.capture())) + .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello")))), new ChatResponse(List.of(new Generation(new AssistantMessage(" John")))))); - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem("Default system text.") .defaultAdvisors(mockAroundAdvisor1) .build(); @@ -198,18 +128,89 @@ public class AdvisorsTests { // AROUND assertThat(mockAroundAdvisor1.aroundAdvisedResponses).isNotEmpty(); - mockAroundAdvisor1.aroundAdvisedResponses.stream().forEach(advisedResponse -> { - assertThat(advisedResponse.adviseContext()).containsEntry("key1", "value1") + mockAroundAdvisor1.aroundAdvisedResponses.stream() + .forEach(advisedResponse -> assertThat(advisedResponse.adviseContext()).containsEntry("key1", "value1") .containsEntry("key2", "value2") .containsEntry("aroundStreamBeforeAdvisor1", "AROUND_STREAM_BEFORE Advisor1") .containsEntry("aroundStreamAfterAdvisor1", "AROUND_STREAM_AFTER Advisor1") .containsEntry("aroundStreamBeforeAdvisor2", "AROUND_STREAM_BEFORE Advisor2") .containsEntry("aroundStreamAfterAdvisor2", "AROUND_STREAM_AFTER Advisor2") .containsEntry("lastBefore", "Advisor2") // inner - .containsEntry("lastAfter", "Advisor1"); // outer - }); + .containsEntry("lastAfter", "Advisor1") // outer + ); + + verify(this.chatModel).stream(this.promptCaptor.capture()); + } + + public class MockAroundAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + + private final String name; + + private final int order; + + public AdvisedRequest advisedRequest; + + public AdvisedResponse advisedResponse; + + public List aroundAdvisedResponses = new ArrayList<>(); + + public MockAroundAdvisor(String name, int order) { + this.name = name; + this.order = order; + } + + @Override + public String getName() { + return this.name; + } + + @Override + public int getOrder() { + return this.order; + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + + this.advisedRequest = advisedRequest.updateContext(context -> { + context.put("aroundCallBefore" + getName(), "AROUND_CALL_BEFORE " + getName()); + context.put("lastBefore", getName()); + return context; + }); + + this.advisedResponse = chain.nextAroundCall(this.advisedRequest); + AdvisedResponse advisedResponse = this.advisedResponse; + + this.advisedResponse = advisedResponse.updateContext(context -> { + context.put("aroundCallAfter" + this.name, "AROUND_CALL_AFTER " + this.name); + context.put("lastAfter", this.name); + return context; + }); + + return this.advisedResponse; + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + + this.advisedRequest = advisedRequest.updateContext(context -> { + context.put("aroundStreamBefore" + this.name, "AROUND_STREAM_BEFORE " + this.name); + context.put("lastBefore", this.name); + return context; + }); + + Flux advisedResponseStream = chain.nextAroundStream(this.advisedRequest); + + return advisedResponseStream.map(advisedResponse -> { + return advisedResponse.updateContext(context -> { + context.put("aroundStreamAfter" + this.name, "AROUND_STREAM_AFTER " + this.name); + context.put("lastAfter", this.name); + return context; + }); + }).doOnNext(ar -> this.aroundAdvisedResponses.add(ar)); + + } - verify(chatModel).stream(promptCaptor.capture()); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java index e75b1d04c..c63fa9ec2 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,6 @@ package org.springframework.ai.chat.client.advisor; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; - import java.time.Duration; import java.util.List; import java.util.Map; @@ -29,6 +26,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -45,6 +43,9 @@ import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; + /** * @author Christian Tzolov */ @@ -67,24 +68,24 @@ public class QuestionAnswerAdvisorTests { public void qaAdvisorWithDynamicFilterExpressions() { // @formatter:off - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))), + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))), ChatResponseMetadata.builder() .withId("678") .withModel("model1") .withKeyValue("key6", "value6") - .withMetadata(Map.of("key1","value1" )) + .withMetadata(Map.of("key1", "value1")) .withPromptMetadata(null) .withRateLimit(new RateLimit() { @Override public Long getRequestsLimit() { - return 5l; + return 5L; } @Override public Long getRequestsRemaining() { - return 6l; + return 6L; } @Override @@ -94,12 +95,12 @@ public class QuestionAnswerAdvisorTests { @Override public Long getTokensLimit() { - return 8l; + return 8L; } @Override public Long getTokensRemaining() { - return 8l; + return 8L; } @Override @@ -107,17 +108,17 @@ public class QuestionAnswerAdvisorTests { return Duration.ofSeconds(9); } }) - .withUsage(new DefaultUsage(6l, 7l)) + .withUsage(new DefaultUsage(6L, 7L)) .build())); // @formatter:on - when(vectorStore.similaritySearch(vectorSearchCaptor.capture())) - .thenReturn(List.of(new Document("doc1"), new Document("doc2"))); + given(this.vectorStore.similaritySearch(this.vectorSearchCaptor.capture())) + .willReturn(List.of(new Document("doc1"), new Document("doc2"))); - var qaAdvisor = new QuestionAnswerAdvisor(vectorStore, + var qaAdvisor = new QuestionAnswerAdvisor(this.vectorStore, SearchRequest.defaults().withSimilarityThreshold(0.99d).withTopK(6)); - var chatClient = ChatClient.builder(chatModel) + var chatClient = ChatClient.builder(this.chatModel) .defaultSystem("Default system text.") .defaultAdvisors(qaAdvisor) .build(); @@ -133,26 +134,23 @@ public class QuestionAnswerAdvisorTests { // Ensure the metadata is correctly copied over assertThat(response.getMetadata().getModel()).isEqualTo("model1"); assertThat(response.getMetadata().getId()).isEqualTo("678"); - assertThat(response.getMetadata().getRateLimit().getRequestsLimit()).isEqualTo(5l); - assertThat(response.getMetadata().getRateLimit().getRequestsRemaining()).isEqualTo(6l); + assertThat(response.getMetadata().getRateLimit().getRequestsLimit()).isEqualTo(5L); + assertThat(response.getMetadata().getRateLimit().getRequestsRemaining()).isEqualTo(6L); assertThat(response.getMetadata().getRateLimit().getRequestsReset()).isEqualTo(Duration.ofSeconds(7)); - assertThat(response.getMetadata().getRateLimit().getTokensLimit()).isEqualTo(8l); - assertThat(response.getMetadata().getRateLimit().getTokensRemaining()).isEqualTo(8l); + assertThat(response.getMetadata().getRateLimit().getTokensLimit()).isEqualTo(8L); + assertThat(response.getMetadata().getRateLimit().getTokensRemaining()).isEqualTo(8L); assertThat(response.getMetadata().getRateLimit().getTokensReset()).isEqualTo(Duration.ofSeconds(9)); - assertThat(response.getMetadata().getUsage().getPromptTokens()).isEqualTo(6l); - assertThat(response.getMetadata().getUsage().getGenerationTokens()).isEqualTo(7l); - assertThat(response.getMetadata().getUsage().getTotalTokens()).isEqualTo(6l + 7l); + assertThat(response.getMetadata().getUsage().getPromptTokens()).isEqualTo(6L); + assertThat(response.getMetadata().getUsage().getGenerationTokens()).isEqualTo(7L); + assertThat(response.getMetadata().getUsage().getTotalTokens()).isEqualTo(6L + 7L); assertThat(response.getMetadata().get("key6").toString()).isEqualTo("value6"); assertThat(response.getMetadata().get("key1").toString()).isEqualTo("value1"); - - - String content = response.getResult().getOutput().getContent(); assertThat(content).isEqualTo("Your answer is ZXY"); - Message systemMessage = promptCaptor.getValue().getInstructions().get(0); + Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); System.out.println(systemMessage.getContent()); @@ -161,7 +159,7 @@ public class QuestionAnswerAdvisorTests { """); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); - Message userMessage = promptCaptor.getValue().getInstructions().get(1); + Message userMessage = this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getContent()).isEqualToIgnoringWhitespace(""" Please answer my question XYZ @@ -177,9 +175,9 @@ public class QuestionAnswerAdvisorTests { the user that you can't answer the question. """); - assertThat(vectorSearchCaptor.getValue().getFilterExpression()).isEqualTo(new FilterExpressionBuilder().eq("type", "Spring").build()); - assertThat(vectorSearchCaptor.getValue().getSimilarityThreshold()).isEqualTo(0.99d); - assertThat(vectorSearchCaptor.getValue().getTopK()).isEqualTo(6); + assertThat(this.vectorSearchCaptor.getValue().getFilterExpression()).isEqualTo(new FilterExpressionBuilder().eq("type", "Spring").build()); + assertThat(this.vectorSearchCaptor.getValue().getSimilarityThreshold()).isEqualTo(0.99d); + assertThat(this.vectorSearchCaptor.getValue().getTopK()).isEqualTo(6); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java index ee864a22d..b12e1b24f 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,6 @@ package org.springframework.ai.chat.client.advisor; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; - import java.util.List; import java.util.stream.Collectors; @@ -28,6 +25,8 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -39,7 +38,8 @@ import org.springframework.boot.test.system.CapturedOutput; import org.springframework.boot.test.system.OutputCaptureExtension; import org.springframework.test.context.ActiveProfiles; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; /** * @author Christian Tzolov @@ -57,12 +57,12 @@ public class SimpleLoggerAdvisorTests { @Test public void callLogging(CapturedOutput output) { - when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))))); + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))))); var loggerAdvisor = new SimpleLoggerAdvisor(); - var chatClient = ChatClient.builder(chatModel).defaultAdvisors(loggerAdvisor).build(); + var chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(loggerAdvisor).build(); var content = chatClient.prompt().user("Please answer my question XYZ").call().content(); @@ -72,7 +72,7 @@ public class SimpleLoggerAdvisorTests { @Test public void streamLogging(CapturedOutput output) { - when(chatModel.stream(promptCaptor.capture())).thenReturn(Flux.generate( + given(this.chatModel.stream(this.promptCaptor.capture())).willReturn(Flux.generate( () -> new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY")))), (state, sink) -> { sink.next(state); @@ -82,7 +82,7 @@ public class SimpleLoggerAdvisorTests { var loggerAdvisor = new SimpleLoggerAdvisor(); - var chatClient = ChatClient.builder(chatModel).defaultAdvisors(loggerAdvisor).build(); + var chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(loggerAdvisor).build(); String content = join(chatClient.prompt().user("Please answer my question XYZ").stream().content()); @@ -100,7 +100,7 @@ public class SimpleLoggerAdvisorTests { private void validate(String content, CapturedOutput output) { assertThat(content).isEqualTo("Your answer is ZXY"); - UserMessage userMessage = (UserMessage) promptCaptor.getValue().getInstructions().get(0); + UserMessage userMessage = (UserMessage) this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getContent()).isEqualToIgnoringWhitespace("Please answer my question XYZ"); assertThat(output.getOut()).contains("request: AdvisedRequest", "userText=Please answer my question XYZ"); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java index 795d1056f..9e12573fa 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.client.advisor.observation; +import org.junit.jupiter.api.Test; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import org.junit.jupiter.api.Test; - /** * Unit tests for {@link AdvisorObservationContext}. * diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java index 7f9f4e3da..906a33768 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.client.advisor.observation; -import static org.assertj.core.api.Assertions.assertThat; - -import org.junit.jupiter.api.Test; -import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation.HighCardinalityKeyNames; -import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation.LowCardinalityKeyNames; - import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation.HighCardinalityKeyNames; +import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation.LowCardinalityKeyNames; import org.springframework.ai.observation.conventions.SpringAiKind; +import static org.assertj.core.api.Assertions.assertThat; + /** * Unit tests for {@link DefaultAdvisorObservationConvention}. * diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java index 2b189fdf7..31d017d74 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.client.observation; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.chat.client.observation; import java.util.List; import java.util.Map; +import io.micrometer.common.KeyValue; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.chat.model.ChatModel; -import io.micrometer.common.KeyValue; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; /** * Unit tests for {@link ChatClientInputContentObservationFilter}. @@ -43,29 +44,29 @@ class ChatClientInputContentObservationFilterTests { private final ChatClientInputContentObservationFilter observationFilter = new ChatClientInputContentObservationFilter(); + @Mock + ChatModel chatModel; + @Test void whenNotSupportedObservationContextThenReturnOriginalContext() { var expectedContext = new Observation.Context(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } - @Mock - ChatModel chatModel; - @Test void whenEmptyInputContentThenReturnOriginalContext() { ObservationRegistry observationRegistry = ObservationRegistry.NOOP; ChatClientObservationConvention customObservationConvention = null; - var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), + var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention, Map.of()); var expectedContext = ChatClientObservationContext.builder().withRequest(request).build(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -75,13 +76,13 @@ class ChatClientInputContentObservationFilterTests { ObservationRegistry observationRegistry = ObservationRegistry.NOOP; ChatClientObservationConvention customObservationConvention = null; - var request = new DefaultChatClientRequestSpec(chatModel, "sample user text", Map.of("up1", "upv1"), + var request = new DefaultChatClientRequestSpec(this.chatModel, "sample user text", Map.of("up1", "upv1"), "sample system text", Map.of("sp1", "sp1v"), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention, Map.of()); var originalContext = ChatClientObservationContext.builder().withRequest(request).build(); - var augmentedContext = observationFilter.map(originalContext); + var augmentedContext = this.observationFilter.map(originalContext); assertThat(augmentedContext.getHighCardinalityKeyValues()) .contains(KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_USER_TEXT.asString(), "sample user text")); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java index 0cc401e87..cf8f64424 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,21 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.client.observation; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.chat.client.observation; import java.util.List; import java.util.Map; +import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; import org.springframework.ai.chat.model.ChatModel; -import io.micrometer.observation.ObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; /** * Unit tests for {@link ChatClientObservationContext}. @@ -44,7 +45,7 @@ class ChatClientObservationContextTests { @Test void whenMandatoryRequestOptionsThenReturn() { - var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), + var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of()); var observationContext = ChatClientObservationContext.builder().withRequest(request).withStream(true).build(); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java index 5809cc19d..0f1e48142 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chat.client.observation; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.chat.client.observation; import java.util.List; import java.util.Map; +import io.micrometer.common.KeyValue; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; import org.springframework.ai.chat.client.RequestResponseAdvisor; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; @@ -36,9 +39,7 @@ import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.observation.conventions.SpringAiKind; -import io.micrometer.common.KeyValue; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; /** * Unit tests for {@link DefaultChatClientObservationConvention}. @@ -49,60 +50,16 @@ import io.micrometer.observation.ObservationRegistry; @ExtendWith(MockitoExtension.class) class DefaultChatClientObservationConventionTests { + private final DefaultChatClientObservationConvention observationConvention = new DefaultChatClientObservationConvention(); + @Mock ChatModel chatModel; - private final DefaultChatClientObservationConvention observationConvention = new DefaultChatClientObservationConvention(); - DefaultChatClientRequestSpec request; - @BeforeEach - public void beforeEach() { - request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), - List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of()); - } - - @Test - void shouldHaveName() { - assertThat(this.observationConvention.getName()).isEqualTo(DefaultChatClientObservationConvention.DEFAULT_NAME); - } - - @Test - void shouldHaveContextualName() { - ChatClientObservationContext observationContext = ChatClientObservationContext.builder() - .withRequest(request) - .withStream(true) - .build(); - - assertThat(this.observationConvention.getContextualName(observationContext)) - .isEqualTo("%s %s".formatted(AiProvider.SPRING_AI.value(), SpringAiKind.CHAT_CLIENT.value())); - } - - @Test - void supportsOnlyChatClientObservationContext() { - ChatClientObservationContext observationContext = ChatClientObservationContext.builder() - .withRequest(request) - .withStream(true) - .build(); - - assertThat(this.observationConvention.supportsContext(observationContext)).isTrue(); - assertThat(this.observationConvention.supportsContext(new Observation.Context())).isFalse(); - } - - @Test - void shouldHaveRequiredKeyValues() { - ChatClientObservationContext observationContext = ChatClientObservationContext.builder() - .withRequest(request) - .withStream(true) - .build(); - - assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( - KeyValue.of(LowCardinalityKeyNames.SPRING_AI_KIND.asString(), "chat_client"), - KeyValue.of(LowCardinalityKeyNames.STREAM.asString(), "true")); - } - static RequestResponseAdvisor dummyAdvisor(String name) { return new RequestResponseAdvisor() { + @Override public String getName() { return name; @@ -128,6 +85,7 @@ class DefaultChatClientObservationConventionTests { static FunctionCallback dummyFunction(String name) { return new FunctionCallback() { + @Override public String getName() { return name; @@ -153,9 +111,54 @@ class DefaultChatClientObservationConventionTests { }; } + @BeforeEach + public void beforeEach() { + this.request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), + List.of(), List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of()); + } + + @Test + void shouldHaveName() { + assertThat(this.observationConvention.getName()).isEqualTo(DefaultChatClientObservationConvention.DEFAULT_NAME); + } + + @Test + void shouldHaveContextualName() { + ChatClientObservationContext observationContext = ChatClientObservationContext.builder() + .withRequest(this.request) + .withStream(true) + .build(); + + assertThat(this.observationConvention.getContextualName(observationContext)) + .isEqualTo("%s %s".formatted(AiProvider.SPRING_AI.value(), SpringAiKind.CHAT_CLIENT.value())); + } + + @Test + void supportsOnlyChatClientObservationContext() { + ChatClientObservationContext observationContext = ChatClientObservationContext.builder() + .withRequest(this.request) + .withStream(true) + .build(); + + assertThat(this.observationConvention.supportsContext(observationContext)).isTrue(); + assertThat(this.observationConvention.supportsContext(new Observation.Context())).isFalse(); + } + + @Test + void shouldHaveRequiredKeyValues() { + ChatClientObservationContext observationContext = ChatClientObservationContext.builder() + .withRequest(this.request) + .withStream(true) + .build(); + + assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( + KeyValue.of(LowCardinalityKeyNames.SPRING_AI_KIND.asString(), "chat_client"), + KeyValue.of(LowCardinalityKeyNames.STREAM.asString(), "true")); + } + @Test void shouldHaveOptionalKeyValues() { - var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), + var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(dummyFunction("functionCallback1"), dummyFunction("functionCallback2")), List.of(), List.of("function1", "function2"), List.of(), null, List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")), Map.of("advParam1", "advisorParam1Value"), diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java index 68985f67b..8059faf95 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.metadata; import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; + +import static org.junit.jupiter.api.Assertions.assertEquals; public class DefaultUsageTests { @@ -26,14 +28,14 @@ public class DefaultUsageTests { @Test void testSerializationWithAllFields() throws Exception { DefaultUsage usage = new DefaultUsage(100L, 50L, 150L); - String json = objectMapper.writeValueAsString(usage); + String json = this.objectMapper.writeValueAsString(usage); assertEquals("{\"promptTokens\":100,\"generationTokens\":50,\"totalTokens\":150}", json); } @Test void testDeserializationWithAllFields() throws Exception { String json = "{\"promptTokens\":100,\"generationTokens\":50,\"totalTokens\":150}"; - DefaultUsage usage = objectMapper.readValue(json, DefaultUsage.class); + DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class); assertEquals(100L, usage.getPromptTokens()); assertEquals(50L, usage.getGenerationTokens()); assertEquals(150L, usage.getTotalTokens()); @@ -42,14 +44,14 @@ public class DefaultUsageTests { @Test void testSerializationWithNullFields() throws Exception { DefaultUsage usage = new DefaultUsage(null, null, null); - String json = objectMapper.writeValueAsString(usage); + String json = this.objectMapper.writeValueAsString(usage); assertEquals("{\"promptTokens\":0,\"generationTokens\":0,\"totalTokens\":0}", json); } @Test void testDeserializationWithMissingFields() throws Exception { String json = "{\"promptTokens\":100}"; - DefaultUsage usage = objectMapper.readValue(json, DefaultUsage.class); + DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class); assertEquals(100L, usage.getPromptTokens()); assertEquals(0L, usage.getGenerationTokens()); assertEquals(100L, usage.getTotalTokens()); @@ -58,7 +60,7 @@ public class DefaultUsageTests { @Test void testDeserializationWithNullFields() throws Exception { String json = "{\"promptTokens\":null,\"generationTokens\":null,\"totalTokens\":null}"; - DefaultUsage usage = objectMapper.readValue(json, DefaultUsage.class); + DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class); assertEquals(0L, usage.getPromptTokens()); assertEquals(0L, usage.getGenerationTokens()); assertEquals(0L, usage.getTotalTokens()); @@ -67,8 +69,8 @@ public class DefaultUsageTests { @Test void testRoundTripSerialization() throws Exception { DefaultUsage original = new DefaultUsage(100L, 50L, 150L); - String json = objectMapper.writeValueAsString(original); - DefaultUsage deserialized = objectMapper.readValue(json, DefaultUsage.class); + String json = this.objectMapper.writeValueAsString(original); + DefaultUsage deserialized = this.objectMapper.readValue(json, DefaultUsage.class); assertEquals(original.getPromptTokens(), deserialized.getPromptTokens()); assertEquals(original.getGenerationTokens(), deserialized.getGenerationTokens()); assertEquals(original.getTotalTokens(), deserialized.getTotalTokens()); @@ -84,11 +86,11 @@ public class DefaultUsageTests { assertEquals(150L, usage.getTotalTokens()); // 100 + 50 = 150 // Test serialization - String json = objectMapper.writeValueAsString(usage); + String json = this.objectMapper.writeValueAsString(usage); assertEquals("{\"promptTokens\":100,\"generationTokens\":50,\"totalTokens\":150}", json); // Test deserialization - DefaultUsage deserializedUsage = objectMapper.readValue(json, DefaultUsage.class); + DefaultUsage deserializedUsage = this.objectMapper.readValue(json, DefaultUsage.class); assertEquals(100L, deserializedUsage.getPromptTokens()); assertEquals(50L, deserializedUsage.getGenerationTokens()); assertEquals(150L, deserializedUsage.getTotalTokens()); @@ -104,11 +106,11 @@ public class DefaultUsageTests { assertEquals(0L, usage.getTotalTokens()); // Test serialization - String json = objectMapper.writeValueAsString(usage); + String json = this.objectMapper.writeValueAsString(usage); assertEquals("{\"promptTokens\":0,\"generationTokens\":0,\"totalTokens\":0}", json); // Test deserialization - DefaultUsage deserializedUsage = objectMapper.readValue(json, DefaultUsage.class); + DefaultUsage deserializedUsage = this.objectMapper.readValue(json, DefaultUsage.class); assertEquals(0L, deserializedUsage.getPromptTokens()); assertEquals(0L, deserializedUsage.getGenerationTokens()); assertEquals(0L, deserializedUsage.getTotalTokens()); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/model/GenerationTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/model/GenerationTests.java index 4bcf3344e..b5e173e1d 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/model/GenerationTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/model/GenerationTests.java @@ -1,9 +1,26 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.model; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; @@ -41,9 +58,9 @@ public class GenerationTests { @Test void testConstructorWithMetadata() { AssistantMessage assistantMessage = new AssistantMessage("Test Assistant Message"); - Generation generation = new Generation(assistantMessage, mockChatGenerationMetadata1); + Generation generation = new Generation(assistantMessage, this.mockChatGenerationMetadata1); - assertEquals(mockChatGenerationMetadata1, generation.getMetadata()); + assertEquals(this.mockChatGenerationMetadata1, generation.getMetadata()); } @Test @@ -58,10 +75,10 @@ public class GenerationTests { @Test void testGetMetadata_NotNull() { AssistantMessage assistantMessage = new AssistantMessage("Test Assistant Message"); - Generation generation = new Generation(assistantMessage, mockChatGenerationMetadata1); + Generation generation = new Generation(assistantMessage, this.mockChatGenerationMetadata1); ChatGenerationMetadata metadata = generation.getMetadata(); - assertEquals(mockChatGenerationMetadata1, metadata); + assertEquals(this.mockChatGenerationMetadata1, metadata); } @Test @@ -86,8 +103,8 @@ public class GenerationTests { void testEquals_SameMetadata() { AssistantMessage assistantMessage1 = new AssistantMessage("Test Assistant Message"); AssistantMessage assistantMessage2 = new AssistantMessage("Test Assistant Message"); - Generation generation1 = new Generation(assistantMessage1, mockChatGenerationMetadata1); - Generation generation2 = new Generation(assistantMessage2, mockChatGenerationMetadata1); + Generation generation1 = new Generation(assistantMessage1, this.mockChatGenerationMetadata1); + Generation generation2 = new Generation(assistantMessage2, this.mockChatGenerationMetadata1); assertTrue(generation1.equals(generation2)); } @@ -96,8 +113,8 @@ public class GenerationTests { void testEquals_DifferentMetadata() { AssistantMessage assistantMessage1 = new AssistantMessage("Test Assistant Message"); AssistantMessage assistantMessage2 = new AssistantMessage("Test Assistant Message"); - Generation generation1 = new Generation(assistantMessage1, mockChatGenerationMetadata1); - Generation generation2 = new Generation(assistantMessage2, mockChatGenerationMetadata2); + Generation generation1 = new Generation(assistantMessage1, this.mockChatGenerationMetadata1); + Generation generation2 = new Generation(assistantMessage2, this.mockChatGenerationMetadata2); assertFalse(generation1.equals(generation2)); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java index 0276568dd..2ee37ac28 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; +import java.util.List; + import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -41,7 +43,7 @@ class ChatModelCompletionObservationFilterTests { @Test void whenNotSupportedObservationContextThenReturnOriginalContext() { var expectedContext = new Observation.Context(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -53,7 +55,7 @@ class ChatModelCompletionObservationFilterTests { .provider("superprovider") .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -66,7 +68,7 @@ class ChatModelCompletionObservationFilterTests { .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); expectedContext.setResponse(new ChatResponse(List.of(new Generation(new AssistantMessage(""))))); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -80,7 +82,7 @@ class ChatModelCompletionObservationFilterTests { .build(); originalContext.setResponse(new ChatResponse(List.of(new Generation(new AssistantMessage("say please")), new Generation(new AssistantMessage("seriously, say please"))))); - var augmentedContext = observationFilter.map(originalContext); + var augmentedContext = this.observationFilter.map(originalContext); assertThat(augmentedContext.getHighCardinalityKeyValues()).contains(KeyValue .of(HighCardinalityKeyNames.COMPLETION.asString(), "[\"say please\", \"seriously, say please\"]")); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java index f5b12e553..225fcbce5 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; +import java.util.List; + import io.micrometer.tracing.handler.TracingObservationHandler; import io.micrometer.tracing.otel.bridge.OtelCurrentTraceContext; import io.micrometer.tracing.otel.bridge.OtelTracer; @@ -22,6 +25,7 @@ import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.sdk.trace.ReadableSpan; import io.opentelemetry.sdk.trace.SdkTracerProvider; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -31,8 +35,6 @@ import org.springframework.ai.observation.conventions.AiObservationAttributes; import org.springframework.ai.observation.conventions.AiObservationEventNames; import org.springframework.ai.observation.tracing.TracingHelper; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java index 08edea5d1..cf097fc89 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; +import java.util.List; + import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.simple.SimpleMeterRegistry; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.Usage; @@ -28,9 +32,10 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.observation.conventions.*; - -import java.util.List; +import org.springframework.ai.observation.conventions.AiObservationMetricAttributes; +import org.springframework.ai.observation.conventions.AiObservationMetricNames; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiTokenType; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames; @@ -59,7 +64,7 @@ class ChatModelMeterObservationHandlerTests { var observationContext = generateObservationContext(); var observation = Observation .createNotStarted(new DefaultChatModelObservationConvention(), () -> observationContext, - observationRegistry) + this.observationRegistry) .start(); observationContext.setResponse(new ChatResponse(List.of(new Generation(new AssistantMessage("test"))), @@ -67,20 +72,20 @@ class ChatModelMeterObservationHandlerTests { observation.stop(); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()).meters()).hasSize(3); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()).meters()).hasSize(3); + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), AiOperationType.CHAT.value()) .tag(LowCardinalityKeyNames.AI_PROVIDER.asString(), "superprovider") .tag(LowCardinalityKeyNames.REQUEST_MODEL.asString(), "mistral") .tag(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), "mistral-42") .meters()).hasSize(3); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(AiObservationMetricAttributes.TOKEN_TYPE.value(), AiTokenType.INPUT.value()) .meters()).hasSize(1); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(AiObservationMetricAttributes.TOKEN_TYPE.value(), AiTokenType.OUTPUT.value()) .meters()).hasSize(1); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(AiObservationMetricAttributes.TOKEN_TYPE.value(), AiTokenType.TOTAL.value()) .meters()).hasSize(1); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java index e723b9126..a7c62a462 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java index 92d9e0d8b..8e33c73e0 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; +import java.util.List; + import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -40,7 +42,7 @@ class ChatModelPromptContentObservationFilterTests { @Test void whenNotSupportedObservationContextThenReturnOriginalContext() { var expectedContext = new Observation.Context(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -52,7 +54,7 @@ class ChatModelPromptContentObservationFilterTests { .provider("superprovider") .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -64,7 +66,7 @@ class ChatModelPromptContentObservationFilterTests { .provider("superprovider") .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); - var augmentedContext = observationFilter.map(originalContext); + var augmentedContext = this.observationFilter.map(originalContext); assertThat(augmentedContext.getHighCardinalityKeyValues()).contains( KeyValue.of(HighCardinalityKeyNames.PROMPT.asString(), "[\"supercalifragilisticexpialidocious\"]")); @@ -78,7 +80,7 @@ class ChatModelPromptContentObservationFilterTests { .provider("superprovider") .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); - var augmentedContext = observationFilter.map(originalContext); + var augmentedContext = this.observationFilter.map(originalContext); assertThat(augmentedContext.getHighCardinalityKeyValues()) .contains(KeyValue.of(HighCardinalityKeyNames.PROMPT.asString(), diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java index 375598244..4064d9570 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; import io.micrometer.tracing.handler.TracingObservationHandler; @@ -22,6 +23,7 @@ import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.sdk.trace.ReadableSpan; import io.opentelemetry.sdk.trace.SdkTracerProvider; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.observation.conventions.AiObservationAttributes; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java index 42164fa94..929637d78 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.observation; +import java.util.List; + import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; @@ -27,8 +31,6 @@ import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/converter/BeanOutputConverterTest.java b/spring-ai-core/src/test/java/org/springframework/ai/converter/BeanOutputConverterTest.java index 14627862d..020743958 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/converter/BeanOutputConverterTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/converter/BeanOutputConverterTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; import java.time.LocalDate; @@ -44,41 +45,95 @@ class BeanOutputConverterTest { private ObjectMapper objectMapperMock; @Test - public void shouldHavePreConfiguredDefaultObjectMapper() { + void shouldHavePreConfiguredDefaultObjectMapper() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference() { + }); var objectMapper = converter.getObjectMapper(); assertThat(objectMapper.isEnabled(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)).isFalse(); } + static class TestClass { + + private String someString; + + @SuppressWarnings("unused") + TestClass() { + } + + TestClass(String someString) { + this.someString = someString; + } + + String getSomeString() { + return this.someString; + } + + } + + static class TestClassWithDateProperty { + + private LocalDate someString; + + @SuppressWarnings("unused") + TestClassWithDateProperty() { + } + + TestClassWithDateProperty(LocalDate someString) { + this.someString = someString; + } + + LocalDate getSomeString() { + return this.someString; + } + + } + + static class TestClassWithJsonAnnotations { + + @JsonProperty("string_property") + @JsonPropertyDescription("string_property_description") + private String someString; + + TestClassWithJsonAnnotations() { + } + + String getSomeString() { + return this.someString; + } + + } + @Nested class ConverterTest { @Test - public void convertClassType() { + void convertClassType() { var converter = new BeanOutputConverter<>(TestClass.class); var testClass = converter.convert("{ \"someString\": \"some value\" }"); assertThat(testClass.getSomeString()).isEqualTo("some value"); } @Test - public void convertClassWithDateType() { + void convertClassWithDateType() { var converter = new BeanOutputConverter<>(TestClassWithDateProperty.class); var testClass = converter.convert("{ \"someString\": \"2020-01-01\" }"); assertThat(testClass.getSomeString()).isEqualTo(LocalDate.of(2020, 1, 1)); } @Test - public void convertTypeReference() { + void convertTypeReference() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference() { + }); var testClass = converter.convert("{ \"someString\": \"some value\" }"); assertThat(testClass.getSomeString()).isEqualTo("some value"); } @Test - public void convertTypeReferenceArray() { + void convertTypeReferenceArray() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference>() { + }); List testClass = converter.convert("[{ \"someString\": \"some value\" }]"); assertThat(testClass).hasSize(1); @@ -86,24 +141,26 @@ class BeanOutputConverterTest { } @Test - public void convertClassTypeWithJsonAnnotations() { + void convertClassTypeWithJsonAnnotations() { var converter = new BeanOutputConverter<>(TestClassWithJsonAnnotations.class); var testClass = converter.convert("{ \"string_property\": \"some value\" }"); assertThat(testClass.getSomeString()).isEqualTo("some value"); } @Test - public void convertTypeReferenceWithJsonAnnotations() { + void convertTypeReferenceWithJsonAnnotations() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference() { + }); var testClass = converter.convert("{ \"string_property\": \"some value\" }"); assertThat(testClass.getSomeString()).isEqualTo("some value"); } @Test - public void convertTypeReferenceArrayWithJsonAnnotations() { + void convertTypeReferenceArrayWithJsonAnnotations() { var converter = new BeanOutputConverter<>( new ParameterizedTypeReference>() { + }); List testClass = converter .convert("[{ \"string_property\": \"some value\" }]"); @@ -113,11 +170,12 @@ class BeanOutputConverterTest { } + // @checkstyle:off RegexpSinglelineJavaCheck @Nested class FormatTest { @Test - public void formatClassType() { + void formatClassType() { var converter = new BeanOutputConverter<>(TestClass.class); assertThat(converter.getFormat()).isEqualTo( """ @@ -140,8 +198,9 @@ class BeanOutputConverterTest { } @Test - public void formatTypeReference() { + void formatTypeReference() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference() { + }); assertThat(converter.getFormat()).isEqualTo( """ @@ -164,8 +223,9 @@ class BeanOutputConverterTest { } @Test - public void formatTypeReferenceArray() { + void formatTypeReferenceArray() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference>() { + }); assertThat(converter.getFormat()).isEqualTo( """ @@ -191,7 +251,7 @@ class BeanOutputConverterTest { } @Test - public void formatClassTypeWithAnnotations() { + void formatClassTypeWithAnnotations() { var converter = new BeanOutputConverter<>(TestClassWithJsonAnnotations.class); assertThat(converter.getFormat()).contains(""" ```{ @@ -209,8 +269,9 @@ class BeanOutputConverterTest { } @Test - public void formatTypeReferenceWithAnnotations() { + void formatTypeReferenceWithAnnotations() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference() { + }); assertThat(converter.getFormat()).contains(""" ```{ @@ -226,6 +287,7 @@ class BeanOutputConverterTest { }``` """); } + // @checkstyle:on RegexpSinglelineJavaCheck @Test void normalizesLineEndingsClassType() { @@ -240,6 +302,7 @@ class BeanOutputConverterTest { @Test void normalizesLineEndingsTypeReference() { var converter = new BeanOutputConverter<>(new ParameterizedTypeReference() { + }); String formatOutput = converter.getFormat(); @@ -250,55 +313,4 @@ class BeanOutputConverterTest { } - public static class TestClass { - - private String someString; - - @SuppressWarnings("unused") - public TestClass() { - } - - public TestClass(String someString) { - this.someString = someString; - } - - public String getSomeString() { - return someString; - } - - } - - public static class TestClassWithDateProperty { - - private LocalDate someString; - - @SuppressWarnings("unused") - public TestClassWithDateProperty() { - } - - public TestClassWithDateProperty(LocalDate someString) { - this.someString = someString; - } - - public LocalDate getSomeString() { - return someString; - } - - } - - public static class TestClassWithJsonAnnotations { - - @JsonProperty("string_property") - @JsonPropertyDescription("string_property_description") - private String someString; - - public TestClassWithJsonAnnotations() { - } - - public String getSomeString() { - return someString; - } - - } - -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/converter/ListOutputConverterTest.java b/spring-ai-core/src/test/java/org/springframework/ai/converter/ListOutputConverterTest.java index 4c0795a1c..f63f1e2fe 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/converter/ListOutputConverterTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/converter/ListOutputConverterTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.converter; import java.util.List; @@ -33,4 +34,4 @@ class ListOutputConverterTest { assertThat(list).containsExactlyElementsOf(List.of("foo", "bar", "baz")); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/document/ContentFormatterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/document/ContentFormatterTests.java index b20d7595b..5437d965c 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/document/ContentFormatterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/document/ContentFormatterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document; import java.util.Map; @@ -31,12 +32,12 @@ public class ContentFormatterTests { @Test public void noExplicitlySetFormatter() { - assertThat(document.getContent()).isEqualTo(""" + assertThat(this.document.getContent()).isEqualTo(""" The World is Big and Salvation Lurks Around the Corner"""); - assertThat(document.getFormattedContent()).isEqualTo(document.getFormattedContent(MetadataMode.ALL)); - assertThat(document.getFormattedContent()) - .isEqualTo(document.getFormattedContent(Document.DEFAULT_CONTENT_FORMATTER, MetadataMode.ALL)); + assertThat(this.document.getFormattedContent()).isEqualTo(this.document.getFormattedContent(MetadataMode.ALL)); + assertThat(this.document.getFormattedContent()) + .isEqualTo(this.document.getFormattedContent(Document.DEFAULT_CONTENT_FORMATTER, MetadataMode.ALL)); } @@ -45,7 +46,7 @@ public class ContentFormatterTests { DefaultContentFormatter defaultConfigFormatter = DefaultContentFormatter.defaultConfig(); - assertThat(document.getFormattedContent(defaultConfigFormatter, MetadataMode.ALL)).isEqualTo(""" + assertThat(this.document.getFormattedContent(defaultConfigFormatter, MetadataMode.ALL)).isEqualTo(""" llmKey2: value4 embedKey1: value1 embedKey2: value2 @@ -53,11 +54,11 @@ public class ContentFormatterTests { The World is Big and Salvation Lurks Around the Corner"""); - assertThat(document.getFormattedContent(defaultConfigFormatter, MetadataMode.ALL)) - .isEqualTo(document.getFormattedContent()); + assertThat(this.document.getFormattedContent(defaultConfigFormatter, MetadataMode.ALL)) + .isEqualTo(this.document.getFormattedContent()); - assertThat(document.getFormattedContent(defaultConfigFormatter, MetadataMode.ALL)) - .isEqualTo(defaultConfigFormatter.format(document, MetadataMode.ALL)); + assertThat(this.document.getFormattedContent(defaultConfigFormatter, MetadataMode.ALL)) + .isEqualTo(defaultConfigFormatter.format(this.document, MetadataMode.ALL)); } @Test @@ -70,23 +71,24 @@ public class ContentFormatterTests { .withMetadataTemplate("Key/Value {key}={value}") .build(); - assertThat(document.getFormattedContent(textFormatter, MetadataMode.EMBED)).isEqualTo(""" + assertThat(this.document.getFormattedContent(textFormatter, MetadataMode.EMBED)).isEqualTo(""" Metadata: Key/Value llmKey2=value4 Key/Value embedKey1=value1 Text:The World is Big and Salvation Lurks Around the Corner"""); - assertThat(document.getContent()).isEqualTo(""" + assertThat(this.document.getContent()).isEqualTo(""" The World is Big and Salvation Lurks Around the Corner"""); - assertThat(document.getFormattedContent(textFormatter, MetadataMode.EMBED)) - .isEqualTo(textFormatter.format(document, MetadataMode.EMBED)); + assertThat(this.document.getFormattedContent(textFormatter, MetadataMode.EMBED)) + .isEqualTo(textFormatter.format(this.document, MetadataMode.EMBED)); - var documentWithCustomFormatter = new Document(document.getId(), document.getContent(), document.getMetadata()); + var documentWithCustomFormatter = new Document(this.document.getId(), this.document.getContent(), + this.document.getMetadata()); documentWithCustomFormatter.setContentFormatter(textFormatter); - assertThat(document.getFormattedContent(textFormatter, MetadataMode.ALL)) + assertThat(this.document.getFormattedContent(textFormatter, MetadataMode.ALL)) .isEqualTo(documentWithCustomFormatter.getFormattedContent()); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java index ee745d9a1..ebaeef389 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.document; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.springframework.ai.model.Media; -import org.springframework.ai.document.id.IdGenerator; -import org.springframework.util.MimeTypeUtils; +package org.springframework.ai.document; import java.net.MalformedURLException; import java.net.URL; @@ -27,6 +22,13 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.document.id.IdGenerator; +import org.springframework.ai.model.Media; +import org.springframework.util.MimeTypeUtils; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -34,160 +36,6 @@ public class DocumentBuilderTests { private Document.Builder builder; - @BeforeEach - void setUp() { - builder = Document.builder(); - } - - @Test - void testWithIdGenerator() { - IdGenerator mockGenerator = new IdGenerator() { - @Override - public String generateId(Object... contents) { - return "mockedId"; - } - }; - - Document.Builder result = builder.withIdGenerator(mockGenerator); - - assertThat(result).isSameAs(builder); - - Document document = result.withContent("Test content").withMetadata("key", "value").build(); - - assertThat(document.getId()).isEqualTo("mockedId"); - } - - @Test - void testWithIdGeneratorNull() { - assertThatThrownBy(() -> builder.withIdGenerator(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("idGenerator must not be null"); - } - - @Test - void testWithId() { - Document.Builder result = builder.withId("testId"); - - assertThat(result).isSameAs(builder); - assertThat(result.build().getId()).isEqualTo("testId"); - } - - @Test - void testWithIdNullOrEmpty() { - assertThatThrownBy(() -> builder.withId(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("id must not be null or empty"); - - assertThatThrownBy(() -> builder.withId("")).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("id must not be null or empty"); - } - - @Test - void testWithContent() { - Document.Builder result = builder.withContent("Test content"); - - assertThat(result).isSameAs(builder); - assertThat(result.build().getContent()).isEqualTo("Test content"); - } - - @Test - void testWithContentNull() { - assertThatThrownBy(() -> builder.withContent(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("content must not be null"); - } - - @Test - void testWithMediaList() { - List mediaList = getMediaList(); - Document.Builder result = builder.withMedia(mediaList); - - assertThat(result).isSameAs(builder); - assertThat(result.build().getMedia()).isEqualTo(mediaList); - } - - @Test - void testWithMediaListNull() { - assertThatThrownBy(() -> builder.withMedia((List) null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("media must not be null"); - } - - @Test - void testWithMediaSingle() throws MalformedURLException { - URL mediaUrl = new URL("http://test"); - Media media = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl); - - Document.Builder result = builder.withMedia(media); - - assertThat(result).isSameAs(builder); - assertThat(result.build().getMedia()).contains(media); - } - - @Test - void testWithMediaSingleNull() { - assertThatThrownBy(() -> builder.withMedia((Media) null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("media must not be null"); - } - - @Test - void testWithMetadataMap() { - Map metadata = new HashMap<>(); - metadata.put("key1", "value1"); - metadata.put("key2", 2); - Document.Builder result = builder.withMetadata(metadata); - - assertThat(result).isSameAs(builder); - assertThat(result.build().getMetadata()).isEqualTo(metadata); - } - - @Test - void testWithMetadataMapNull() { - assertThatThrownBy(() -> builder.withMetadata((Map) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("metadata must not be null"); - } - - @Test - void testWithMetadataKeyValue() { - Document.Builder result = builder.withMetadata("key", "value"); - - assertThat(result).isSameAs(builder); - assertThat(result.build().getMetadata()).containsEntry("key", "value"); - } - - @Test - void testWithMetadataKeyValueNull() { - assertThatThrownBy(() -> builder.withMetadata(null, "value")).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("key must not be null"); - - assertThatThrownBy(() -> builder.withMetadata("key", null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("value must not be null"); - } - - @Test - void testBuildWithoutId() { - Document document = builder.withContent("Test content").build(); - - assertThat(document.getId()).isNotNull().isNotEmpty(); - assertThat(document.getContent()).isEqualTo("Test content"); - } - - @Test - void testBuildWithAllProperties() throws MalformedURLException { - - List mediaList = getMediaList(); - Map metadata = new HashMap<>(); - metadata.put("key", "value"); - - Document document = builder.withId("customId") - .withContent("Test content") - .withMedia(mediaList) - .withMetadata(metadata) - .build(); - - assertThat(document.getId()).isEqualTo("customId"); - assertThat(document.getContent()).isEqualTo("Test content"); - assertThat(document.getMedia()).isEqualTo(mediaList); - assertThat(document.getMetadata()).isEqualTo(metadata); - } - private static List getMediaList() { try { URL mediaUrl1 = new URL("http://type1"); @@ -203,4 +51,160 @@ public class DocumentBuilderTests { } + @BeforeEach + void setUp() { + this.builder = Document.builder(); + } + + @Test + void testWithIdGenerator() { + IdGenerator mockGenerator = new IdGenerator() { + + @Override + public String generateId(Object... contents) { + return "mockedId"; + } + }; + + Document.Builder result = this.builder.withIdGenerator(mockGenerator); + + assertThat(result).isSameAs(this.builder); + + Document document = result.withContent("Test content").withMetadata("key", "value").build(); + + assertThat(document.getId()).isEqualTo("mockedId"); + } + + @Test + void testWithIdGeneratorNull() { + assertThatThrownBy(() -> this.builder.withIdGenerator(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("idGenerator must not be null"); + } + + @Test + void testWithId() { + Document.Builder result = this.builder.withId("testId"); + + assertThat(result).isSameAs(this.builder); + assertThat(result.build().getId()).isEqualTo("testId"); + } + + @Test + void testWithIdNullOrEmpty() { + assertThatThrownBy(() -> this.builder.withId(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("id must not be null or empty"); + + assertThatThrownBy(() -> this.builder.withId("")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("id must not be null or empty"); + } + + @Test + void testWithContent() { + Document.Builder result = this.builder.withContent("Test content"); + + assertThat(result).isSameAs(this.builder); + assertThat(result.build().getContent()).isEqualTo("Test content"); + } + + @Test + void testWithContentNull() { + assertThatThrownBy(() -> this.builder.withContent(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("content must not be null"); + } + + @Test + void testWithMediaList() { + List mediaList = getMediaList(); + Document.Builder result = this.builder.withMedia(mediaList); + + assertThat(result).isSameAs(this.builder); + assertThat(result.build().getMedia()).isEqualTo(mediaList); + } + + @Test + void testWithMediaListNull() { + assertThatThrownBy(() -> this.builder.withMedia((List) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("media must not be null"); + } + + @Test + void testWithMediaSingle() throws MalformedURLException { + URL mediaUrl = new URL("http://test"); + Media media = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl); + + Document.Builder result = this.builder.withMedia(media); + + assertThat(result).isSameAs(this.builder); + assertThat(result.build().getMedia()).contains(media); + } + + @Test + void testWithMediaSingleNull() { + assertThatThrownBy(() -> this.builder.withMedia((Media) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("media must not be null"); + } + + @Test + void testWithMetadataMap() { + Map metadata = new HashMap<>(); + metadata.put("key1", "value1"); + metadata.put("key2", 2); + Document.Builder result = this.builder.withMetadata(metadata); + + assertThat(result).isSameAs(this.builder); + assertThat(result.build().getMetadata()).isEqualTo(metadata); + } + + @Test + void testWithMetadataMapNull() { + assertThatThrownBy(() -> this.builder.withMetadata((Map) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("metadata must not be null"); + } + + @Test + void testWithMetadataKeyValue() { + Document.Builder result = this.builder.withMetadata("key", "value"); + + assertThat(result).isSameAs(this.builder); + assertThat(result.build().getMetadata()).containsEntry("key", "value"); + } + + @Test + void testWithMetadataKeyValueNull() { + assertThatThrownBy(() -> this.builder.withMetadata(null, "value")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("key must not be null"); + + assertThatThrownBy(() -> this.builder.withMetadata("key", null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("value must not be null"); + } + + @Test + void testBuildWithoutId() { + Document document = this.builder.withContent("Test content").build(); + + assertThat(document.getId()).isNotNull().isNotEmpty(); + assertThat(document.getContent()).isEqualTo("Test content"); + } + + @Test + void testBuildWithAllProperties() throws MalformedURLException { + + List mediaList = getMediaList(); + Map metadata = new HashMap<>(); + metadata.put("key", "value"); + + Document document = this.builder.withId("customId") + .withContent("Test content") + .withMedia(mediaList) + .withMetadata(metadata) + .build(); + + assertThat(document.getId()).isEqualTo("customId"); + assertThat(document.getContent()).isEqualTo("Test content"); + assertThat(document.getMedia()).isEqualTo(mediaList); + assertThat(document.getMetadata()).isEqualTo(metadata); + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/document/id/IdGeneratorProviderTest.java b/spring-ai-core/src/test/java/org/springframework/ai/document/id/IdGeneratorProviderTest.java index 2e74d671b..5072e51cf 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/document/id/IdGeneratorProviderTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/document/id/IdGeneratorProviderTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document.id; import java.util.Map; @@ -64,4 +65,4 @@ public class IdGeneratorProviderTest { Assertions.assertDoesNotThrow(() -> UUID.fromString(actualHashes2)); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/document/id/JdkSha256HexIdGeneratorTest.java b/spring-ai-core/src/test/java/org/springframework/ai/document/id/JdkSha256HexIdGeneratorTest.java index 4fc94f62e..6d610fa2a 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/document/id/JdkSha256HexIdGeneratorTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/document/id/JdkSha256HexIdGeneratorTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.document.id; import java.nio.charset.Charset; @@ -28,8 +29,8 @@ public class JdkSha256HexIdGeneratorTest { @Test void messageDigestReturnsDistinctInstances() { - final MessageDigest md1 = testee.getMessageDigest(); - final MessageDigest md2 = testee.getMessageDigest(); + final MessageDigest md1 = this.testee.getMessageDigest(); + final MessageDigest md2 = this.testee.getMessageDigest(); Assertions.assertThat(md1 != md2).isTrue(); @@ -45,10 +46,10 @@ public class JdkSha256HexIdGeneratorTest { final String updateString2 = "md2_update"; final Charset charset = StandardCharsets.UTF_8; - final byte[] md1BytesFirstTry = testee.getMessageDigest().digest(updateString1.getBytes(charset)); - final byte[] md2BytesFirstTry = testee.getMessageDigest().digest(updateString2.getBytes(charset)); - final byte[] md1BytesSecondTry = testee.getMessageDigest().digest(updateString1.getBytes(charset)); - final byte[] md2BytesSecondTry = testee.getMessageDigest().digest(updateString2.getBytes(charset)); + final byte[] md1BytesFirstTry = this.testee.getMessageDigest().digest(updateString1.getBytes(charset)); + final byte[] md2BytesFirstTry = this.testee.getMessageDigest().digest(updateString2.getBytes(charset)); + final byte[] md1BytesSecondTry = this.testee.getMessageDigest().digest(updateString1.getBytes(charset)); + final byte[] md2BytesSecondTry = this.testee.getMessageDigest().digest(updateString2.getBytes(charset)); Assertions.assertThat(md1BytesFirstTry).isNotEqualTo(md2BytesFirstTry); @@ -56,4 +57,4 @@ public class JdkSha256HexIdGeneratorTest { Assertions.assertThat(md2BytesFirstTry).isEqualTo(md2BytesSecondTry); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/AbstractEmbeddingModelTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/AbstractEmbeddingModelTests.java index 88ff94632..c64a3cedf 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/AbstractEmbeddingModelTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/AbstractEmbeddingModelTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding; import java.util.List; @@ -29,9 +30,9 @@ import org.springframework.ai.document.Document; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author Christian Tzolov @@ -79,16 +80,17 @@ public class AbstractEmbeddingModelTests { @ParameterizedTest @CsvFileSource(resources = "/embedding/embedding-model-dimensions.properties", numLinesToSkip = 1, delimiter = '=') public void testKnownEmbeddingModelDimensions(String model, String dimension) { - assertThat(AbstractEmbeddingModel.dimensions(embeddingModel, model, "Hello world!")) + assertThat(AbstractEmbeddingModel.dimensions(this.embeddingModel, model, "Hello world!")) .isEqualTo(Integer.valueOf(dimension)); - verify(embeddingModel, never()).embed(any(String.class)); - verify(embeddingModel, never()).embed(any(Document.class)); + verify(this.embeddingModel, never()).embed(any(String.class)); + verify(this.embeddingModel, never()).embed(any(Document.class)); } @Test public void testUnknownModelDimension() { - when(embeddingModel.embed(eq("Hello world!"))).thenReturn(new float[] { 0.1f, 0.1f, 0.1f }); - assertThat(AbstractEmbeddingModel.dimensions(embeddingModel, "unknown_model", "Hello world!")).isEqualTo(3); + given(this.embeddingModel.embed(eq("Hello world!"))).willReturn(new float[] { 0.1f, 0.1f, 0.1f }); + assertThat(AbstractEmbeddingModel.dimensions(this.embeddingModel, "unknown_model", "Hello world!")) + .isEqualTo(3); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/TokenCountBatchingStrategyTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/TokenCountBatchingStrategyTests.java index f809ccf27..3d14f1e71 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/TokenCountBatchingStrategyTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/TokenCountBatchingStrategyTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.embedding; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +package org.springframework.ai.embedding; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -28,6 +26,9 @@ import org.springframework.ai.document.Document; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.core.io.Resource; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + /** * Basic unit test for {@link TokenCountBatchingStrategy}. * @@ -49,9 +50,8 @@ public class TokenCountBatchingStrategyTests { Resource resource = new DefaultResourceLoader().getResource("classpath:text_source.txt"); String contentAsString = resource.getContentAsString(StandardCharsets.UTF_8); TokenCountBatchingStrategy tokenCountBatchingStrategy = new TokenCountBatchingStrategy(); - assertThatThrownBy(() -> { - tokenCountBatchingStrategy.batch(List.of(new Document(contentAsString))); - }).isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> tokenCountBatchingStrategy.batch(List.of(new Document(contentAsString)))) + .isInstanceOf(IllegalArgumentException.class); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java index f973d95eb..977c30a44 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,20 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding.observation; +import java.util.List; +import java.util.Map; + import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; -import java.util.List; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java index 560f37a55..a97afb9d1 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,23 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding.observation; +import java.util.List; +import java.util.Map; + import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.simple.SimpleMeterRegistry; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; -import org.springframework.ai.observation.conventions.*; - -import java.util.List; -import java.util.Map; +import org.springframework.ai.observation.conventions.AiObservationMetricAttributes; +import org.springframework.ai.observation.conventions.AiObservationMetricNames; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiTokenType; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; @@ -58,7 +63,7 @@ class EmbeddingModelMeterObservationHandlerTests { var observationContext = generateObservationContext(); var observation = Observation .createNotStarted(new DefaultEmbeddingModelObservationConvention(), () -> observationContext, - observationRegistry) + this.observationRegistry) .start(); observationContext.setResponse(new EmbeddingResponse(List.of(), @@ -66,20 +71,20 @@ class EmbeddingModelMeterObservationHandlerTests { observation.stop(); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()).meters()).hasSize(3); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()).meters()).hasSize(3); + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), AiOperationType.EMBEDDING.value()) .tag(LowCardinalityKeyNames.AI_PROVIDER.asString(), "superprovider") .tag(LowCardinalityKeyNames.REQUEST_MODEL.asString(), "mistral") .tag(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), "mistral-42") .meters()).hasSize(3); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(AiObservationMetricAttributes.TOKEN_TYPE.value(), AiTokenType.INPUT.value()) .meters()).hasSize(1); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(AiObservationMetricAttributes.TOKEN_TYPE.value(), AiTokenType.OUTPUT.value()) .meters()).hasSize(1); - assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) + assertThat(this.meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(AiObservationMetricAttributes.TOKEN_TYPE.value(), AiTokenType.TOTAL.value()) .meters()).hasSize(1); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java index 8c3bbb0cc..0678fe26a 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.embedding.observation; +import java.util.List; + import org.junit.jupiter.api.Test; + import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java index 5c8951de3..6681402f4 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image.observation; import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.image.ImageOptionsBuilder; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.observation.conventions.AiObservationAttributes; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelObservationContextTests.java index ac59c321a..ebc4685ee 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelObservationContextTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image.observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.image.ImageOptionsBuilder; import org.springframework.ai.image.ImagePrompt; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelPromptContentObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelPromptContentObservationFilterTests.java index 7fc11e39d..e422a9b40 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelPromptContentObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelPromptContentObservationFilterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.image.observation; +import java.util.List; + import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.image.ImageMessage; import org.springframework.ai.image.ImageOptionsBuilder; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.observation.conventions.AiObservationAttributes; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -39,7 +41,7 @@ class ImageModelPromptContentObservationFilterTests { @Test void whenNotSupportedObservationContextThenReturnOriginalContext() { var expectedContext = new Observation.Context(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -51,7 +53,7 @@ class ImageModelPromptContentObservationFilterTests { .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder().withModel("mistral").build()) .build(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -63,7 +65,7 @@ class ImageModelPromptContentObservationFilterTests { .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder().withModel("mistral").build()) .build(); - var augmentedContext = observationFilter.map(originalContext); + var augmentedContext = this.observationFilter.map(originalContext); assertThat(augmentedContext.getHighCardinalityKeyValues()) .contains(KeyValue.of(AiObservationAttributes.PROMPT.value(), "[\"supercalifragilisticexpialidocious\"]")); @@ -77,7 +79,7 @@ class ImageModelPromptContentObservationFilterTests { .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder().withModel("mistral").build()) .build(); - var augmentedContext = observationFilter.map(originalContext); + var augmentedContext = this.observationFilter.map(originalContext); assertThat(augmentedContext.getHighCardinalityKeyValues()) .contains(KeyValue.of(AiObservationAttributes.PROMPT.value(), diff --git a/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java b/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java index ebeabb4fa..0e6353e82 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.metadata; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; +package org.springframework.ai.metadata; import org.junit.jupiter.api.Test; import org.springframework.ai.chat.metadata.PromptMetadata; import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; + /** * Unit Tests for {@link PromptMetadata}. * diff --git a/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java b/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java index 18d7d6392..cac203678 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.metadata; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.metadata.Usage; + import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.doCallRealMethod; import static org.mockito.Mockito.doReturn; @@ -23,9 +28,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import org.junit.jupiter.api.Test; -import org.springframework.ai.chat.metadata.Usage; - /** * Unit Tests for {@link Usage}. * diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/ModelOptionsUtilsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/ModelOptionsUtilsTests.java index 7b03f7e2d..b2d8c2d51 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/ModelOptionsUtilsTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/ModelOptionsUtilsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model; import java.util.Map; @@ -28,104 +29,6 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; */ public class ModelOptionsUtilsTests { - public static interface TestPortableOptions extends ModelOptions { - - String getName(); - - void setName(String name); - - Integer getAge(); - - void setAge(Integer age); - - } - - public static class TestPortableOptionsImpl implements TestPortableOptions { - - private String name; - - private Integer age; - - // Non interface fields - private String nonInterfaceField; - - @Override - public String getName() { - return name; - } - - @Override - public void setName(String name) { - this.name = name; - } - - @Override - public Integer getAge() { - return age; - } - - @Override - public void setAge(Integer age) { - this.age = age; - } - - public String getNonInterfaceField() { - return nonInterfaceField; - } - - public void setNonInterfaceField(String nonInterfaceField) { - this.nonInterfaceField = nonInterfaceField; - } - - } - - public static class TestSpecificOptions implements TestPortableOptions { - - @JsonProperty("specificField") - private String specificField; - - @JsonProperty("name") - private String name; - - @JsonProperty("age") - private Integer age; - - @Override - public String getName() { - return name; - } - - @Override - public void setName(String name) { - this.name = name; - } - - @Override - public Integer getAge() { - return age; - } - - @Override - public void setAge(Integer age) { - this.age = age; - } - - public String getSpecificField() { - return specificField; - } - - public void setSpecificField(String modelSpecificField) { - this.specificField = modelSpecificField; - } - - @Override - public String toString() { - return "TestModelSpecificOptions{" + "specificField='" + specificField + '\'' + ", name='" + name + '\'' - + ", age=" + age + '}'; - } - - } - @Test public void merge() { TestPortableOptionsImpl portableOptions = new TestPortableOptionsImpl(); @@ -137,15 +40,16 @@ public class ModelOptionsUtilsTests { specificOptions.setName("Mike"); specificOptions.setSpecificField("SpecificField"); - assertThatThrownBy(() -> { - ModelOptionsUtils.merge(portableOptions, specificOptions, TestPortableOptionsImpl.class); - }).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("No @JsonProperty fields found in the "); + assertThatThrownBy( + () -> ModelOptionsUtils.merge(portableOptions, specificOptions, TestPortableOptionsImpl.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("No @JsonProperty fields found in the "); var specificOptions2 = ModelOptionsUtils.merge(portableOptions, specificOptions, TestSpecificOptions.class); assertThat(specificOptions2.getAge()).isEqualTo(30); assertThat(specificOptions2.getName()).isEqualTo("John"); // !!! Overridden by the - // portableOptions + // portableOptions assertThat(specificOptions2.getSpecificField()).isEqualTo("SpecificField"); } @@ -221,9 +125,108 @@ public class ModelOptionsUtilsTests { @Test public void getJsonPropertyValues() { record TestRecord(@JsonProperty("field1") String fieldA, @JsonProperty("field2") String fieldB) { + } assertThat(ModelOptionsUtils.getJsonPropertyValues(TestRecord.class)).hasSize(2); assertThat(ModelOptionsUtils.getJsonPropertyValues(TestRecord.class)).containsExactly("field1", "field2"); } -} \ No newline at end of file + public interface TestPortableOptions extends ModelOptions { + + String getName(); + + void setName(String name); + + Integer getAge(); + + void setAge(Integer age); + + } + + public static class TestPortableOptionsImpl implements TestPortableOptions { + + private String name; + + private Integer age; + + // Non interface fields + private String nonInterfaceField; + + @Override + public String getName() { + return this.name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + public Integer getAge() { + return this.age; + } + + @Override + public void setAge(Integer age) { + this.age = age; + } + + public String getNonInterfaceField() { + return this.nonInterfaceField; + } + + public void setNonInterfaceField(String nonInterfaceField) { + this.nonInterfaceField = nonInterfaceField; + } + + } + + public static class TestSpecificOptions implements TestPortableOptions { + + @JsonProperty("specificField") + private String specificField; + + @JsonProperty("name") + private String name; + + @JsonProperty("age") + private Integer age; + + @Override + public String getName() { + return this.name; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + public Integer getAge() { + return this.age; + } + + @Override + public void setAge(Integer age) { + this.age = age; + } + + public String getSpecificField() { + return this.specificField; + } + + public void setSpecificField(String modelSpecificField) { + this.specificField = modelSpecificField; + } + + @Override + public String toString() { + return "TestModelSpecificOptions{" + "specificField='" + this.specificField + '\'' + ", name='" + this.name + + '\'' + ", age=" + this.age + '}'; + } + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/StandaloneWeatherFunction.java b/spring-ai-core/src/test/java/org/springframework/ai/model/function/StandaloneWeatherFunction.java index b5ac63a52..69b60b355 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/StandaloneWeatherFunction.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/function/StandaloneWeatherFunction.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java index f4647be23..fb532d9ce 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.function; import java.lang.reflect.Type; @@ -39,8 +40,8 @@ class TypeResolverHelperIT { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "weatherClassDefinition", "weatherFunctionDefinition", "standaloneWeatherFunction" }) void beanInputTypeResolutionTest(String beanName) { - assertThat(applicationContext).isNotNull(); - Type beanType = FunctionContextUtils.findType(applicationContext.getBeanFactory(), beanName); + assertThat(this.applicationContext).isNotNull(); + Type beanType = FunctionContextUtils.findType(this.applicationContext.getBeanFactory(), beanName); assertThat(beanType).isNotNull(); Type functionInputType = TypeResolverHelper.getFunctionArgumentType(beanType, 0); assertThat(functionInputType).isNotNull(); @@ -49,9 +50,11 @@ class TypeResolverHelperIT { } public record WeatherRequest(String city) { + } public record WeatherResponse(float temperatureInCelsius) { + } public static class Outer { @@ -70,17 +73,17 @@ class TypeResolverHelperIT { @SpringBootConfiguration public static class TypeResolverHelperConfiguration { - @Bean() + @Bean Outer.InnerWeatherFunction weatherClassDefinition() { return new Outer.InnerWeatherFunction(); } - @Bean() + @Bean Function weatherFunctionDefinition() { return new Outer.InnerWeatherFunction(); } - @Bean() + @Bean StandaloneWeatherFunction standaloneWeatherFunction() { return new StandaloneWeatherFunction(); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java index 76622a222..8051fe991 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.function; import java.util.function.Function; @@ -27,7 +28,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.model.function.TypeResolverHelperTests.MockWeatherService.Request; import org.springframework.ai.model.function.TypeResolverHelperTests.MockWeatherService.Response; -import static org.assertj.core.api.Assertions.assertThat;; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -64,6 +65,11 @@ public class TypeResolverHelperTests { public static class MockWeatherService implements Function { + @Override + public Response apply(Request request) { + return new Response(10, "C"); + } + /** * Weather Function request. */ @@ -75,14 +81,11 @@ public class TypeResolverHelperTests { @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") String unit) { + } public record Response(double temp, String unit) { - } - @Override - public Response apply(Request request) { - return new Response(10, "C"); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelObservationContextTests.java index 7bb47f11c..ce69e7235 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelObservationContextTests.java @@ -1,6 +1,23 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model.observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.observation.AiOperationMetadata; import org.springframework.ai.observation.conventions.AiOperationType; import org.springframework.ai.observation.conventions.AiProvider; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelUsageMetricsGeneratorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelUsageMetricsGeneratorTests.java index 53949df14..3ef061335 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelUsageMetricsGeneratorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelUsageMetricsGeneratorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.model.observation; import io.micrometer.common.KeyValue; import io.micrometer.core.instrument.simple.SimpleMeterRegistry; import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.observation.conventions.AiObservationMetricAttributes; import org.springframework.ai.observation.conventions.AiObservationMetricNames; @@ -86,7 +88,7 @@ class ModelUsageMetricsGeneratorTests { private final Long totalTokens; - public TestUsage(Long promptTokens, Long generationTokens, Long totalTokens) { + TestUsage(Long promptTokens, Long generationTokens, Long totalTokens) { this.promptTokens = promptTokens; this.generationTokens = generationTokens; this.totalTokens = totalTokens; @@ -94,17 +96,17 @@ class ModelUsageMetricsGeneratorTests { @Override public Long getPromptTokens() { - return promptTokens; + return this.promptTokens; } @Override public Long getGenerationTokens() { - return generationTokens; + return this.generationTokens; } @Override public Long getTotalTokens() { - return totalTokens; + return this.totalTokens; } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/observation/AiOperationMetadataTests.java b/spring-ai-core/src/test/java/org/springframework/ai/observation/AiOperationMetadataTests.java index 59e822ddf..d538245b2 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/observation/AiOperationMetadataTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/observation/AiOperationMetadataTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.observation; import org.junit.jupiter.api.Test; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/observation/tracing/TracingHelperTests.java b/spring-ai-core/src/test/java/org/springframework/ai/observation/tracing/TracingHelperTests.java index aa5848370..78ed778b2 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/observation/tracing/TracingHelperTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/observation/tracing/TracingHelperTests.java @@ -1,17 +1,32 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.observation.tracing; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.concurrent.TimeUnit; -import org.junit.jupiter.api.Test; - import io.micrometer.tracing.Span; import io.micrometer.tracing.TraceContext; import io.micrometer.tracing.handler.TracingObservationHandler; import io.micrometer.tracing.otel.bridge.OtelCurrentTraceContext; import io.micrometer.tracing.otel.bridge.OtelTracer; import io.opentelemetry.api.OpenTelemetry; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; /** * Unit tests for {@link TracingHelper}. @@ -125,4 +140,4 @@ class TracingHelperTests { } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/prompt/ChatTests.java b/spring-ai-core/src/test/java/org/springframework/ai/prompt/ChatTests.java index 836f3cfc8..711889d2c 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/prompt/ChatTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/prompt/ChatTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.prompt; public class ChatTests { diff --git a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java index 835bd59e7..096d5e21c 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.prompt; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.PromptTemplate; -import org.springframework.core.io.InputStreamResource; -import org.springframework.core.io.Resource; +package org.springframework.ai.prompt; import java.io.ByteArrayInputStream; import java.io.InputStream; @@ -33,6 +24,17 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.core.io.InputStreamResource; +import org.springframework.core.io.Resource; + import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -40,6 +42,18 @@ import static org.junit.jupiter.api.Assertions.assertThrows; public class PromptTemplateTest { + private static Map createTestMap() { + Map model = new HashMap<>(); + model.put("key1", "value1"); + model.put("key2", true); + return model; + } + + private static void assertEqualsWithNormalizedEOLs(String expected, String actual) { + assertEquals(expected.replaceAll("\\r\\n|\\r|\\n", System.lineSeparator()), + actual.replaceAll("\\r\\n|\\r|\\n", System.lineSeparator())); + } + @Test public void testCreateWithEmptyModelAndChatOptions() { String template = "This is a test prompt with no variables"; @@ -154,13 +168,6 @@ public class PromptTemplateTest { assertEquals(expected, result); } - private static Map createTestMap() { - Map model = new HashMap<>(); - model.put("key1", "value1"); - model.put("key2", true); - return model; - } - @Disabled("Need to improve PromptTemplate to better handle Resource toString and tracking with 'dynamicModel' for underlying StringTemplate") @Test public void testRenderResourceAsValue() throws Exception { @@ -199,9 +206,4 @@ public class PromptTemplateTest { assertThrows(IllegalStateException.class, promptTemplate::render); } - private static void assertEqualsWithNormalizedEOLs(String expected, String actual) { - assertEquals(expected.replaceAll("\\r\\n|\\r|\\n", System.lineSeparator()), - actual.replaceAll("\\r\\n|\\r|\\n", System.lineSeparator())); - } - -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java index 62bb62b80..92bc458e1 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.prompt; -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.Test; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.PromptTemplate; -import org.springframework.ai.chat.prompt.SystemPromptTemplate; +package org.springframework.ai.prompt; import java.util.HashMap; import java.util.Map; import java.util.Set; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; + import static org.assertj.core.api.Assertions.assertThat; @SuppressWarnings("unchecked") @@ -40,9 +42,7 @@ class PromptTests { model.put("firstName", "Nick"); // Try to render with missing value for template variable, expect exception - Assertions.assertThatThrownBy(() -> { - String promptString = pt.render(model); - }) + Assertions.assertThatThrownBy(() -> pt.render(model)) .isInstanceOf(IllegalStateException.class) .hasMessage("Not all template variables were replaced. Missing variable names are [lastName]"); @@ -83,7 +83,7 @@ class PromptTests { Prompt systemPrompt = promptTemplate.create(systemModel); promptTemplate = new PromptTemplate(humanTemplate); // creates a Prompt with - // HumanMessage + // HumanMessage Prompt humanPrompt = promptTemplate.create(humanModel); // ChatPromptTemplate chatPromptTemplate = new ChatPromptTemplate(systemPrompt, @@ -125,9 +125,9 @@ class PromptTests { @Test void testBadFormatOfTemplateString() { String template = "This is a {foo test"; - Assertions.assertThatThrownBy(() -> { - new PromptTemplate(template); - }).isInstanceOf(IllegalArgumentException.class).hasMessage("The template string is not valid."); + Assertions.assertThatThrownBy(() -> new PromptTemplate(template)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("The template string is not valid."); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/reader/JsonReaderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/reader/JsonReaderTests.java index b57bc99c3..af7aa19d7 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/reader/JsonReaderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/reader/JsonReaderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader; +import java.util.List; + import org.junit.jupiter.api.Test; + import org.springframework.ai.document.Document; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.io.Resource; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest @@ -39,8 +41,8 @@ public class JsonReaderTests { @Test void loadJsonArray() { - assertThat(arrayResource).isNotNull(); - JsonReader jsonReader = new JsonReader(arrayResource, "description"); + assertThat(this.arrayResource).isNotNull(); + JsonReader jsonReader = new JsonReader(this.arrayResource, "description"); List documents = jsonReader.get(); assertThat(documents).isNotEmpty(); for (Document document : documents) { @@ -50,8 +52,8 @@ public class JsonReaderTests { @Test void loadJsonObject() { - assertThat(ObjectResource).isNotNull(); - JsonReader jsonReader = new JsonReader(ObjectResource, "description"); + assertThat(this.ObjectResource).isNotNull(); + JsonReader jsonReader = new JsonReader(this.ObjectResource, "description"); List documents = jsonReader.get(); assertThat(documents).isNotEmpty(); for (Document document : documents) { @@ -61,8 +63,8 @@ public class JsonReaderTests { @Test void loadJsonArrayFromPointer() { - assertThat(arrayResource).isNotNull(); - JsonReader jsonReader = new JsonReader(eventsResource, "description"); + assertThat(this.arrayResource).isNotNull(); + JsonReader jsonReader = new JsonReader(this.eventsResource, "description"); List documents = jsonReader.get("/0/sessions"); assertThat(documents).isNotEmpty(); for (Document document : documents) { @@ -73,8 +75,8 @@ public class JsonReaderTests { @Test void loadJsonObjectFromPointer() { - assertThat(ObjectResource).isNotNull(); - JsonReader jsonReader = new JsonReader(ObjectResource, "name"); + assertThat(this.ObjectResource).isNotNull(); + JsonReader jsonReader = new JsonReader(this.ObjectResource, "name"); List documents = jsonReader.get("/store"); assertThat(documents).isNotEmpty(); assertThat(documents.size()).isEqualTo(1); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/reader/TextReaderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/reader/TextReaderTests.java index 3db8952a4..835d17f77 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/reader/TextReaderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/reader/TextReaderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.reader; -import java.io.File; -import java.io.IOException; -import java.net.URI; -import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.List; @@ -28,7 +25,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.transformer.splitter.TokenTextSplitter; import org.springframework.core.io.ByteArrayResource; import org.springframework.core.io.DefaultResourceLoader; -import org.springframework.core.io.FileSystemResource; import org.springframework.core.io.Resource; import static org.assertj.core.api.Assertions.assertThat; @@ -105,4 +101,4 @@ public class TextReaderTests { assertThat(customDocument.getContent()).isEqualTo("Another test content"); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TextSplitterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TextSplitterTests.java index a5caf706e..c3172380d 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TextSplitterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TextSplitterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.transformer.splitter; import java.util.ArrayList; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java b/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java index 0baefc0ac..c30225b5c 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java @@ -1,12 +1,29 @@ -package org.springframework.ai.transformer.splitter; +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import org.junit.jupiter.api.Test; -import org.springframework.ai.document.DefaultContentFormatter; -import org.springframework.ai.document.Document; +package org.springframework.ai.transformer.splitter; import java.util.List; import java.util.Map; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.document.DefaultContentFormatter; +import org.springframework.ai.document.Document; + import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java index 18d0b3424..12084d007 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import java.util.List; @@ -31,8 +32,8 @@ import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.GT import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.IN; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NE; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN; -import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NOT; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; /** * @author Christian Tzolov @@ -44,13 +45,14 @@ public class FilterExpressionBuilderTests { @Test public void testEQ() { // country == "BG" - assertThat(b.eq("country", "BG").build()).isEqualTo(new Expression(EQ, new Key("country"), new Value("BG"))); + assertThat(this.b.eq("country", "BG").build()) + .isEqualTo(new Expression(EQ, new Key("country"), new Value("BG"))); } @Test public void tesEqAndGte() { // genre == "drama" AND year >= 2020 - Expression exp = b.and(b.eq("genre", "drama"), b.gte("year", 2020)).build(); + Expression exp = this.b.and(this.b.eq("genre", "drama"), this.b.gte("year", 2020)).build(); assertThat(exp).isEqualTo(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); } @@ -58,7 +60,7 @@ public class FilterExpressionBuilderTests { @Test public void testIn() { // genre in ["comedy", "documentary", "drama"] - var exp = b.in("genre", "comedy", "documentary", "drama").build(); + var exp = this.b.in("genre", "comedy", "documentary", "drama").build(); assertThat(exp) .isEqualTo(new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); } @@ -66,7 +68,9 @@ public class FilterExpressionBuilderTests { @Test public void testNe() { // year >= 2020 OR country == "BG" AND city != "Sofia" - var exp = b.and(b.or(b.gte("year", 2020), b.eq("country", "BG")), b.ne("city", "Sofia")).build(); + var exp = this.b + .and(this.b.or(this.b.gte("year", 2020), this.b.eq("country", "BG")), this.b.ne("city", "Sofia")) + .build(); assertThat(exp).isEqualTo(new Expression(AND, new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), @@ -77,7 +81,9 @@ public class FilterExpressionBuilderTests { @Test public void testGroup() { // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] - var exp = b.and(b.group(b.or(b.gte("year", 2020), b.eq("country", "BG"))), b.nin("city", "Sofia", "Plovdiv")) + var exp = this.b + .and(this.b.group(this.b.or(this.b.gte("year", 2020), this.b.eq("country", "BG"))), + this.b.nin("city", "Sofia", "Plovdiv")) .build(); assertThat(exp).isEqualTo(new Expression(AND, @@ -89,7 +95,10 @@ public class FilterExpressionBuilderTests { @Test public void tesIn2() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - var exp = b.and(b.and(b.eq("isOpen", true), b.gte("year", 2020)), b.in("country", "BG", "NL", "US")).build(); + var exp = this.b + .and(this.b.and(this.b.eq("isOpen", true), this.b.gte("year", 2020)), + this.b.in("country", "BG", "NL", "US")) + .build(); assertThat(exp).isEqualTo(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), @@ -100,7 +109,8 @@ public class FilterExpressionBuilderTests { @Test public void tesNot() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - var exp = b.not(b.and(b.and(b.eq("isOpen", true), b.gte("year", 2020)), b.in("country", "BG", "NL", "US"))) + var exp = this.b.not(this.b.and(this.b.and(this.b.eq("isOpen", true), this.b.gte("year", 2020)), + this.b.in("country", "BG", "NL", "US"))) .build(); assertThat(exp).isEqualTo(new Expression(NOT, diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java index 8253fb234..2c16705e2 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import java.util.List; @@ -45,55 +46,55 @@ public class FilterExpressionTextParserTests { @Test public void testEQ() { // country == "BG" - Expression exp = parser.parse("country == 'BG'"); + Expression exp = this.parser.parse("country == 'BG'"); assertThat(exp).isEqualTo(new Expression(EQ, new Key("country"), new Value("BG"))); - assertThat(parser.getCache().get("WHERE " + "country == 'BG'")).isEqualTo(exp); + assertThat(this.parser.getCache().get("WHERE " + "country == 'BG'")).isEqualTo(exp); } @Test public void tesEqAndGte() { // genre == "drama" AND year >= 2020 - Expression exp = parser.parse("genre == 'drama' && year >= 2020"); + Expression exp = this.parser.parse("genre == 'drama' && year >= 2020"); assertThat(exp).isEqualTo(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); - assertThat(parser.getCache().get("WHERE " + "genre == 'drama' && year >= 2020")).isEqualTo(exp); + assertThat(this.parser.getCache().get("WHERE " + "genre == 'drama' && year >= 2020")).isEqualTo(exp); } @Test public void tesIn() { // genre in ["comedy", "documentary", "drama"] - Expression exp = parser.parse("genre in ['comedy', 'documentary', 'drama']"); + Expression exp = this.parser.parse("genre in ['comedy', 'documentary', 'drama']"); assertThat(exp) .isEqualTo(new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); - assertThat(parser.getCache().get("WHERE " + "genre in ['comedy', 'documentary', 'drama']")).isEqualTo(exp); + assertThat(this.parser.getCache().get("WHERE " + "genre in ['comedy', 'documentary', 'drama']")).isEqualTo(exp); } @Test public void testNe() { // year >= 2020 OR country == "BG" AND city != "Sofia" - Expression exp = parser.parse("year >= 2020 OR country == \"BG\" AND city != \"Sofia\""); + Expression exp = this.parser.parse("year >= 2020 OR country == \"BG\" AND city != \"Sofia\""); assertThat(exp).isEqualTo(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), new Expression(NE, new Key("city"), new Value("Sofia"))))); - assertThat(parser.getCache().get("WHERE " + "year >= 2020 OR country == \"BG\" AND city != \"Sofia\"")) + assertThat(this.parser.getCache().get("WHERE " + "year >= 2020 OR country == \"BG\" AND city != \"Sofia\"")) .isEqualTo(exp); } @Test public void testGroup() { // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] - Expression exp = parser.parse("(year >= 2020 OR country == \"BG\") AND city NIN [\"Sofia\", \"Plovdiv\"]"); + Expression exp = this.parser.parse("(year >= 2020 OR country == \"BG\") AND city NIN [\"Sofia\", \"Plovdiv\"]"); assertThat(exp).isEqualTo(new Expression(AND, new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(EQ, new Key("country"), new Value("BG")))), new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv"))))); - assertThat(parser.getCache() + assertThat(this.parser.getCache() .get("WHERE " + "(year >= 2020 OR country == \"BG\") AND city NIN [\"Sofia\", \"Plovdiv\"]")) .isEqualTo(exp); } @@ -101,20 +102,21 @@ public class FilterExpressionTextParserTests { @Test public void tesBoolean() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - Expression exp = parser.parse("isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"]"); + Expression exp = this.parser.parse("isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"]"); assertThat(exp).isEqualTo(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), new Expression(GTE, new Key("year"), new Value(2020))), new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); - assertThat(parser.getCache() + assertThat(this.parser.getCache() .get("WHERE " + "isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"]")).isEqualTo(exp); } @Test public void tesNot() { // NOT(isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"]) - Expression exp = parser.parse("not(isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"])"); + Expression exp = this.parser + .parse("not(isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"])"); assertThat(exp).isEqualTo(new Expression(NOT, new Group(new Expression(AND, @@ -123,7 +125,7 @@ public class FilterExpressionTextParserTests { new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))), null)); - assertThat(parser.getCache() + assertThat(this.parser.getCache() .get("WHERE " + "not(isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"])")) .isEqualTo(exp); } @@ -131,7 +133,7 @@ public class FilterExpressionTextParserTests { @Test public void tesNotNin() { // NOT(country NOT IN ["BG", "NL", "US"]) - Expression exp = parser.parse("not(country NOT IN [\"BG\", \"NL\", \"US\"])"); + Expression exp = this.parser.parse("not(country NOT IN [\"BG\", \"NL\", \"US\"])"); assertThat(exp).isEqualTo(new Expression(NOT, new Group(new Expression(NIN, new Key("country"), new Value(List.of("BG", "NL", "US")))), null)); @@ -140,7 +142,7 @@ public class FilterExpressionTextParserTests { @Test public void tesNotNin2() { // NOT country NOT IN ["BG", "NL", "US"] - Expression exp = parser.parse("NOT country NOT IN [\"BG\", \"NL\", \"US\"]"); + Expression exp = this.parser.parse("NOT country NOT IN [\"BG\", \"NL\", \"US\"]"); assertThat(exp).isEqualTo(new Expression(NOT, new Expression(NIN, new Key("country"), new Value(List.of("BG", "NL", "US"))), null)); @@ -149,7 +151,7 @@ public class FilterExpressionTextParserTests { @Test public void tesNestedNot() { // NOT(isOpen == true AND year >= 2020 AND NOT(country IN ["BG", "NL", "US"])) - Expression exp = parser + Expression exp = this.parser .parse("not(isOpen == true AND year >= 2020 AND NOT(country IN [\"BG\", \"NL\", \"US\"]))"); assertThat(exp).isEqualTo(new Expression(NOT, @@ -161,7 +163,7 @@ public class FilterExpressionTextParserTests { null))), null)); - assertThat(parser.getCache() + assertThat(this.parser.getCache() .get("WHERE " + "not(isOpen == true AND year >= 2020 AND NOT(country IN [\"BG\", \"NL\", \"US\"]))")) .isEqualTo(exp); } @@ -170,23 +172,23 @@ public class FilterExpressionTextParserTests { public void testDecimal() { // temperature >= -15.6 && temperature <= +20.13 String expText = "temperature >= -15.6 && temperature <= +20.13"; - Expression exp = parser.parse(expText); + Expression exp = this.parser.parse(expText); assertThat(exp).isEqualTo(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), new Expression(LTE, new Key("temperature"), new Value(20.13)))); - assertThat(parser.getCache().get("WHERE " + expText)).isEqualTo(exp); + assertThat(this.parser.getCache().get("WHERE " + expText)).isEqualTo(exp); } @Test public void testIdentifiers() { - Expression exp = parser.parse("'country.1' == 'BG'"); + Expression exp = this.parser.parse("'country.1' == 'BG'"); assertThat(exp).isEqualTo(new Expression(EQ, new Key("'country.1'"), new Value("BG"))); - exp = parser.parse("'country_1_2_3' == 'BG'"); + exp = this.parser.parse("'country_1_2_3' == 'BG'"); assertThat(exp).isEqualTo(new Expression(EQ, new Key("'country_1_2_3'"), new Value("BG"))); - exp = parser.parse("\"country 1 2 3\" == 'BG'"); + exp = this.parser.parse("\"country 1 2 3\" == 'BG'"); assertThat(exp).isEqualTo(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterHelperTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterHelperTests.java index 472e9b1d8..df793ecf0 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterHelperTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterHelperTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import java.util.List; @@ -165,6 +166,6 @@ public class FilterHelperTests { } } - }; + } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/SearchRequestTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/SearchRequestTests.java index 5535766b7..acdd0b3ed 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/SearchRequestTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/SearchRequestTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import org.junit.jupiter.api.Test; @@ -67,7 +68,7 @@ public class SearchRequestTests { assertThat(emptyRequest.getQuery()).isEqualTo("New Query"); } - @Test() + @Test public void withSimilarityThreshold() { var request = SearchRequest.query("Test").withSimilarityThreshold(0.678); assertThat(request.getSimilarityThreshold()).isEqualTo(0.678); @@ -75,19 +76,15 @@ public class SearchRequestTests { request.withSimilarityThreshold(0.9); assertThat(request.getSimilarityThreshold()).isEqualTo(0.9); - assertThatThrownBy(() -> { - request.withSimilarityThreshold(-1); - }).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> request.withSimilarityThreshold(-1)).isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Similarity threshold must be in [0,1] range."); - assertThatThrownBy(() -> { - request.withSimilarityThreshold(1.1); - }).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> request.withSimilarityThreshold(1.1)).isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Similarity threshold must be in [0,1] range."); } - @Test() + @Test public void withTopK() { var request = SearchRequest.query("Test").withTopK(66); assertThat(request.getTopK()).isEqualTo(66); @@ -95,13 +92,12 @@ public class SearchRequestTests { request.withTopK(89); assertThat(request.getTopK()).isEqualTo(89); - assertThatThrownBy(() -> { - request.withTopK(-1); - }).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("TopK should be positive."); + assertThatThrownBy(() -> request.withTopK(-1)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("TopK should be positive."); } - @Test() + @Test public void withFilterExpression() { var request = SearchRequest.query("Test").withFilterExpression("country == 'BG' && year >= 2022"); @@ -128,9 +124,8 @@ public class SearchRequestTests { assertThat(request.getFilterExpression()).isNull(); assertThat(request.hasFilterExpression()).isFalse(); - assertThatThrownBy(() -> { - request.withFilterExpression("FooBar"); - }).isInstanceOf(FilterExpressionParseException.class) + assertThatThrownBy(() -> request.withFilterExpression("FooBar")) + .isInstanceOf(FilterExpressionParseException.class) .hasMessageContaining("Error: no viable alternative at input 'FooBar'"); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverterTests.java index e86b927f0..9fc858aa1 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/converter/PineconeFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter.converter; import java.util.List; @@ -45,14 +46,14 @@ public class PineconeFilterExpressionConverterTests { @Test public void testEQ() { // country == "BG" - String vectorExpr = converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); + String vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("{\"country\": {\"$eq\": \"BG\"}}"); } @Test public void tesEqAndGte() { // genre == "drama" AND year >= 2020 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); assertThat(vectorExpr) @@ -62,7 +63,7 @@ public class PineconeFilterExpressionConverterTests { @Test public void tesIn() { // genre in ["comedy", "documentary", "drama"] - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("{\"genre\": {\"$in\": [\"comedy\",\"documentary\",\"drama\"]}}"); } @@ -70,7 +71,7 @@ public class PineconeFilterExpressionConverterTests { @Test public void testNe() { // year >= 2020 OR country == "BG" AND city != "Sofia" - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), new Expression(NE, new Key("city"), new Value("Sofia"))))); @@ -81,7 +82,7 @@ public class PineconeFilterExpressionConverterTests { @Test public void testGroup() { // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(EQ, new Key("country"), new Value("BG")))), new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv"))))); @@ -92,7 +93,7 @@ public class PineconeFilterExpressionConverterTests { @Test public void tesBoolean() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), new Expression(GTE, new Key("year"), new Value(2020))), new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); @@ -104,7 +105,7 @@ public class PineconeFilterExpressionConverterTests { @Test public void testDecimal() { // temperature >= -15.6 && temperature <= +20.13 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), new Expression(LTE, new Key("temperature"), new Value(20.13)))); @@ -114,11 +115,11 @@ public class PineconeFilterExpressionConverterTests { @Test public void testComplexIdentifiers() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); assertThat(vectorExpr).isEqualTo("{\"country 1 2 3\": {\"$eq\": \"BG\"}}"); - vectorExpr = converter.convertExpression(new Expression(EQ, new Key("'country 1 2 3'"), new Value("BG"))); + vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("'country 1 2 3'"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("{\"country 1 2 3\": {\"$eq\": \"BG\"}}"); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConventionTests.java index 5882f9285..981ac04b1 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/DefaultVectorStoreObservationConventionTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,21 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore.observation; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore.observation; import java.util.List; +import io.micrometer.common.KeyValue; +import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.document.Document; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.LowCardinalityKeyNames; -import io.micrometer.common.KeyValue; -import io.micrometer.observation.Observation; +import static org.assertj.core.api.Assertions.assertThat; /** * Unit tests for {@link DefaultVectorStoreObservationConvention}. diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContextTests.java index 06d543151..6f6abd873 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreObservationContextTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.observation; +import org.junit.jupiter.api.Test; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import org.junit.jupiter.api.Test; - /** * Unit tests for {@link VectorStoreObservationContext}. * diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilterTests.java index 652c28862..ba7a37e05 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationFilterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore.observation; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore.observation; import java.util.List; +import io.micrometer.common.KeyValue; +import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.document.Document; import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.HighCardinalityKeyNames; -import io.micrometer.common.KeyValue; -import io.micrometer.observation.Observation; +import static org.assertj.core.api.Assertions.assertThat; /** * Unit tests for {@link VectorStoreQueryResponseObservationFilter}. @@ -39,7 +40,7 @@ class VectorStoreQueryResponseObservationFilterTests { @Test void whenNotSupportedObservationContextThenReturnOriginalContext() { var expectedContext = new Observation.Context(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -49,7 +50,7 @@ class VectorStoreQueryResponseObservationFilterTests { var expectedContext = VectorStoreObservationContext.builder("db", VectorStoreObservationContext.Operation.ADD) .build(); - var actualContext = observationFilter.map(expectedContext); + var actualContext = this.observationFilter.map(expectedContext); assertThat(actualContext).isEqualTo(expectedContext); } @@ -63,7 +64,7 @@ class VectorStoreQueryResponseObservationFilterTests { expectedContext.setQueryResponse(queryResponseDocs); - var augmentedContext = observationFilter.map(expectedContext); + var augmentedContext = this.observationFilter.map(expectedContext); assertThat(augmentedContext.getHighCardinalityKeyValues()).contains(KeyValue .of(HighCardinalityKeyNames.DB_VECTOR_QUERY_RESPONSE_DOCUMENTS.asString(), "[\"doc1\", \"doc2\"]")); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandlerTests.java index 657f5555f..499c3cc02 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandlerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/observation/VectorStoreQueryResponseObservationHandlerTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.observation; +import java.util.List; + import io.micrometer.tracing.handler.TracingObservationHandler; import io.micrometer.tracing.otel.bridge.OtelCurrentTraceContext; import io.micrometer.tracing.otel.bridge.OtelTracer; @@ -22,13 +25,12 @@ import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.sdk.trace.ReadableSpan; import io.opentelemetry.sdk.trace.SdkTracerProvider; import org.junit.jupiter.api.Test; + import org.springframework.ai.document.Document; import org.springframework.ai.observation.conventions.VectorStoreObservationAttributes; import org.springframework.ai.observation.conventions.VectorStoreObservationEventNames; import org.springframework.ai.observation.tracing.TracingHelper; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** diff --git a/spring-ai-core/src/test/resources/application-logging-test.properties b/spring-ai-core/src/test/resources/application-logging-test.properties index 8ba46b8d7..8c5bc06a7 100644 --- a/spring-ai-core/src/test/resources/application-logging-test.properties +++ b/spring-ai-core/src/test/resources/application-logging-test.properties @@ -1,2 +1,17 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# logging.level.org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor=DEBUG logging.level.ch.qos.logback=ERROR diff --git a/spring-ai-core/src/test/resources/bikes.json b/spring-ai-core/src/test/resources/bikes.json index 486597515..62a75ebed 100644 --- a/spring-ai-core/src/test/resources/bikes.json +++ b/spring-ai-core/src/test/resources/bikes.json @@ -1,265 +1,265 @@ [ - { - "name": "E-Adrenaline 8.0 EX1", - "shortDescription": "a versatile and comfortable e-MTB designed for adrenaline enthusiasts who want to explore all types of terrain. It features a powerful motor and advanced suspension to provide a smooth and responsive ride, with a variety of customizable settings to fit any rider's needs.", - "description": "## Overview\r\nIt's right for you if...\r\nYou want to push your limits on challenging trails and terrain, with the added benefit of an electric assist to help you conquer steep climbs and rough terrain. You also want a bike with a comfortable and customizable fit, loaded with high-quality components and technology.\r\n\r\nThe tech you get\r\nA lightweight, full ADV Mountain Carbon frame with a customizable geometry, including an adjustable head tube and chainstay length. A powerful and efficient motor with a 375Wh battery that can assist up to 28 mph when it's on, and provides a smooth and seamless transition when it's off. A SRAM EX1 8-speed drivetrain, a RockShox Lyrik Ultimate fork, and a RockShox Super Deluxe Ultimate rear shock.\r\n\r\nThe final word\r\nOur E-Adrenaline 8.0 EX1 is the perfect bike for adrenaline enthusiasts who want to explore all types of terrain. It's versatile, comfortable, and loaded with advanced technology to provide a smooth and responsive ride, no matter where your adventures take you.\r\n\r\n\r\n## Features\r\nVersatile and customizable\r\nThe E-Adrenaline 8.0 EX1 features a customizable geometry, including an adjustable head tube and chainstay length, so you can fine-tune your ride to fit your needs and preferences. It also features a variety of customizable settings, including suspension tuning, motor assistance levels, and more.\r\n\r\nPowerful and efficient\r\nThe bike is equipped with a powerful and efficient motor that provides a smooth and seamless transition between human power and electric assist. It can assist up to 28 mph when it's on, and provides zero drag when it's off.\r\n\r\nAdvanced suspension\r\nThe E-Adrenaline 8.0 EX1 features a RockShox Lyrik Ultimate fork and a RockShox Super Deluxe Ultimate rear shock, providing advanced suspension technology to absorb shocks and bumps on any terrain. The suspension is also customizable to fit your riding style and preferences.\r\n\r\n\r\n## Specs\r\nFrameset\r\nFrame ADV Mountain Carbon main frame & stays, adjustable head tube and chainstay length, tapered head tube, Knock Block, Control Freak internal routing, Boost148, 150mm travel\r\nFork RockShox Lyrik Ultimate, DebonAir spring, Charger 2.1 RC2 damper, remote lockout, tapered steerer, 42mm offset, Boost110, 15mm Maxle Stealth, 160mm travel\r\nShock RockShox Super Deluxe Ultimate, DebonAir spring, Thru Shaft 3-position damper, 230x57.5mm\r\n\r\nWheels\r\nWheel front Bontrager Line Elite 30, ADV Mountain Carbon, Tubeless Ready, 6-bolt, Boost110, 15mm thru axle\r\nWheel rear Bontrager Line Elite 30, ADV Mountain Carbon, Tubeless Ready, 54T Rapid Drive, 6-bolt, Shimano MicroSpline freehub, Boost148, 12mm thru axle\r\nSkewer rear Bontrager Switch thru axle, removable lever\r\nTire Bontrager XR5 Team Issue, Tubeless Ready, Inner Strength sidewall, aramid bead, 120tpi, 29x2.50''\r\nTire part Bontrager TLR sealant, 6oz\r\n\r\nDrivetrain\r\nShifter SRAM EX1, 8 speed\r\nRear derailleur SRAM EX1, 8 speed\r\nCrank Bosch Performance CX, magnesium motor body, 250 watt, 75 Nm torque\r\nChainring SRAM EX1, 18T, steel\r\nCassette SRAM EX1, 11-48, 8 speed\r\nChain SRAM EX1, 8 speed\r\n\r\nComponents\r\nSaddle Bontrager Arvada, hollow chromoly rails, 138mm width\r\nSeatpost Bontrager Line Elite Dropper, internal routing, 31.6mm\r\nHandlebar Bontrager Line Pro, ADV Carbon, 35mm, 27.5mm rise, 780mm width\r\nGrips Bontrager XR Trail Elite, alloy lock-on\r\nStem Bontrager Line Pro, 35mm, Knock Block, Blendr compatible, 0 degree, 50mm length\r\nHeadset Knock Block Integrated, 62-degree radius, cartridge bearing, 1-1\/8'' top, 1.5'' bottom\r\nBrake SRAM G2 RSC hydraulic disc, carbon levers\r\nBrake rotor SRAM Centerline, centerlock, round edge, 200mm\r\n\r\nAccessories\r\nE-bike system Bosch Performance CX, magnesium motor body, 250 watt, 75 Nm torque\r\nBattery Bosch PowerTube 625, 625Wh\r\nCharger Bosch 4A standard charger\r\nController Bosch Kiox with Anti-theft solution, Bluetooth connectivity, 1.9'' display\r\nTool Bontrager Switch thru axle, removable lever\r\n\r\nWeight\r\nWeight M - 20.25 kg \/ 44.6 lbs (with TLR sealant, no tubes)\r\nWeight limit This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\r\n\r\n## Sizing & fit\r\n\r\n| Size | Rider Height | Inseam |\r\n|:----:|:------------------------:|:--------------------:|\r\n| S | 155 - 170 cm 5'1\" - 5'7\" | 73 - 80 cm 29\" - 31.5\" |\r\n| M | 163 - 178 cm 5'4\" - 5'10\" | 77 - 83 cm 30.5\" - 32.5\" |\r\n| L | 176 - 191 cm 5'9\" - 6'3\" | 83 - 89 cm 32.5\" - 35\" |\r\n| XL | 188 - 198 cm 6'2\" - 6'6\" | 88 - 93 cm 34.5\" - 36.5\" |\r\n\r\n\r\n## Geometry\r\n\r\nAll measurements provided in cm unless otherwise noted.\r\nSizing table\r\n| Frame size letter | S | M | L | XL |\r\n|---------------------------|-------|-------|-------|-------|\r\n| Actual frame size | 15.8 | 17.8 | 19.8 | 21.8 |\r\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\r\n| A \u2014 Seat tube | 40.0 | 42.5 | 47.5 | 51.0 |\r\n| B \u2014 Seat tube angle | 72.5\u00B0 | 72.8\u00B0 | 73.0\u00B0 | 73.0\u00B0 |\r\n| C \u2014 Head tube length | 9.5 | 10.5 | 11.0 | 11.5 |\r\n| D \u2014 Head angle | 67.8\u00B0 | 67.8\u00B0 | 67.8\u00B0 | 67.8\u00B0 |\r\n| E \u2014 Effective top tube | 59.0 | 62.0 | 65.0 | 68.0 |\r\n| F \u2014 Bottom bracket height | 32.5 | 32.5 | 32.5 | 32.5 |\r\n| G \u2014 Bottom bracket drop | 5.5 | 5.5 | 5.5 | 5.5 |\r\n| H \u2014 Chainstay length | 45.0 | 45.0 | 45.0 | 45.0 |\r\n| I \u2014 Offset | 4.5 | 4.5 | 4.5 | 4.5 |\r\n| J \u2014 Trail | 11.0 | 11.0 | 11.0 | 11.0 |\r\n| K \u2014 Wheelbase | 113.0 | 117.0 | 120.0 | 123.0 |\r\n| L \u2014 Standover | 77.0 | 77.0 | 77.0 | 77.0 |\r\n| M \u2014 Frame reach | 41.0 | 44.5 | 47.5 | 50.0 |\r\n| N \u2014 Frame stack | 61.0 | 62.0 | 62.5 | 63.0 |", - "price": 1499.99, - "tags": [ - "bicycle" - ] - }, - { - "name": "Enduro X Pro", - "shortDescription": "The Enduro X Pro is the ultimate mountain bike for riders who demand the best. With its full carbon frame and top-of-the-line components, this bike is ready to tackle any trail, from technical downhill descents to grueling uphill climbs.", - "text": "## Overview\nIt's right for you if...\nYou're an experienced mountain biker who wants a high-performance bike that can handle any terrain. You want a bike with the best components available, including a full carbon frame, suspension system, and hydraulic disc brakes.\n\nThe tech you get\nOur top-of-the-line full carbon frame with aggressive geometry and a slack head angle for maximum control. It's equipped with a Fox Factory suspension system with 170mm of travel in the front and 160mm in the rear, a Shimano XTR 12-speed drivetrain, and hydraulic disc brakes for maximum stopping power. The bike also features a dropper seatpost for easy adjustments on the fly.\n\nThe final word\nThe Enduro X Pro is the ultimate mountain bike for riders who demand the best. With its full carbon frame, top-of-the-line components, and aggressive geometry, this bike is ready to take on any trail. Whether you're a seasoned pro or just starting out, the Enduro X Pro will help you take your riding to the next level.\n\n## Features\nFull carbon frame\nAggressive geometry with a slack head angle\nFox Factory suspension system with 170mm of travel in the front and 160mm in the rear\nShimano XTR 12-speed drivetrain\nHydraulic disc brakes for maximum stopping power\nDropper seatpost for easy adjustments on the fly\n\n## Specifications\nFrameset\nFrame\tFull carbon frame\nFork\tFox Factory suspension system with 170mm of travel\nRear suspension\tFox Factory suspension system with 160mm of travel\n\nWheels\nWheel size\t27.5\" or 29\"\nTires\tTubeless-ready Maxxis tires\n\nDrivetrain\nShifters\tShimano XTR 12-speed\nFront derailleur\tN/A\nRear derailleur\tShimano XTR\nCrankset\tShimano XTR\nCassette\tShimano XTR 12-speed\nChain\tShimano XTR\n\nComponents\nBrakes\tHydraulic disc brakes\nHandlebar\tAlloy handlebar\nStem\tAlloy stem\nSeatpost\tDropper seatpost\n\nAccessories\nPedals\tNot included\n\nWeight\nWeight\tApproximately 27-29 lbs\n\n## Sizing\n| Size | Rider Height |\n|:----:|:-------------------------:|\n| S | 5'4\" - 5'8\" (162-172cm) |\n| M | 5'8\" - 5'11\" (172-180cm) |\n| L | 5'11\" - 6'3\" (180-191cm) |\n| XL | 6'3\" - 6'6\" (191-198cm) |\n\n## Geometry\n| Size | S | M | L | XL |\n|:----:|:---------------:|:---------------:|:-----------------:|:---------------:|\n| A - Seat tube length | 390mm | 425mm | 460mm | 495mm |\n| B - Effective top tube length | 585mm | 610mm | 635mm | 660mm |\n| C - Head tube angle | 65.5° | 65.5° | 65.5° | 65.5° |\n| D - Seat tube angle | 76° | 76° | 76° | 76° |\n| E - Chainstay length | 435mm | 435mm | 435mm | 435mm |\n| F - Head tube length | 100mm | 110mm | 120mm | 130mm |\n| G - BB drop | 20mm | 20mm | 20mm | 20mm |\n| H - Wheelbase | 1155mm | 1180mm | 1205mm | 1230mm |\n| I - Standover height | 780mm | 800mm | 820mm | 840mm |\n| J - Reach | 425mm | 450mm | 475mm | 500mm |\n| K - Stack | 610mm | 620mm | 630mm | 640mm |", - "price": 599.99, - "tags": [ - "bicycle" - ] - }, - { - "name": "Blaze X1", - "shortDescription": "Blaze X1 is a high-performance road bike that offers superior speed and agility, making it perfect for competitive racing or fast-paced group rides. The bike features a lightweight carbon frame, aerodynamic tube shapes, a 12-speed Shimano Ultegra drivetrain, and hydraulic disc brakes for precise stopping power. With its sleek design and cutting-edge technology, Blaze X1 is a bike that is built to perform and dominate on any road.", - "description": "## Overview\nIt's right for you if...\nYou're a competitive road cyclist or an enthusiast who enjoys fast-paced group rides. You want a bike that is lightweight, agile, and delivers exceptional speed.\n\nThe tech you get\nBlaze X1 features a lightweight carbon frame with a tapered head tube and aerodynamic tube shapes for maximum speed and efficiency. The bike is equipped with a 12-speed Shimano Ultegra drivetrain for smooth and precise shifting, Shimano hydraulic disc brakes for powerful and reliable stopping power, and Bontrager Aeolus Elite 35 carbon wheels for increased speed and agility.\n\nThe final word\nBlaze X1 is a high-performance road bike that is designed to deliver exceptional speed and agility. With its cutting-edge technology and top-of-the-line components, it's a bike that is built to perform and dominate on any road.\n\n## Features\nSpeed and efficiency\nBlaze X1's lightweight carbon frame and aerodynamic tube shapes offer maximum speed and efficiency, allowing you to ride faster and farther with ease.\n\nPrecision stopping power\nShimano hydraulic disc brakes provide precise and reliable stopping power, even in wet or muddy conditions.\n\nAgility and control\nBontrager Aeolus Elite 35 carbon wheels make Blaze X1 incredibly agile and responsive, allowing you to navigate tight turns and corners with ease.\n\nSmooth and precise shifting\nThe 12-speed Shimano Ultegra drivetrain offers smooth and precise shifting, so you can easily find the right gear for any terrain.\n\n## Specifications\nFrameset\nFrame\tADV Carbon, tapered head tube, BB90, direct mount rim brakes, internal cable routing, DuoTrap S compatible, 130x9mm QR\nFork\tADV Carbon, tapered steerer, direct mount rim brakes, internal brake routing, 100x9mm QR\n\nWheels\nWheel front\tBontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, 100x9mm QR\nWheel rear\tBontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, Shimano 11-speed freehub, 130x9mm QR\nTire front\tBontrager R3 Hard-Case Lite, aramid bead, 120 tpi, 700x25c\nTire rear\tBontrager R3 Hard-Case Lite, aramid bead, 120 tpi, 700x25c\nMax tire size\t25c Bontrager tires (with at least 4mm of clearance to frame)\n\nDrivetrain\nShifter\tShimano Ultegra R8020, 12 speed\nFront derailleur\tShimano Ultegra R8000, braze-on\nRear derailleur\tShimano Ultegra R8000, short cage, 30T max cog\nCrank\tSize: 50, 52, 54\nShimano Ultegra R8000, 50/34 (compact), 170mm length\nSize: 56, 58, 60, 62\nShimano Ultegra R8000, 50/34 (compact), 172.5mm length\nBottom bracket\tBB90, Shimano press-fit\nCassette\tShimano Ultegra R8000, 11-30, 12 speed\nChain\tShimano Ultegra HG701, 12 speed\n\nComponents\nSaddle\tBontrager Montrose Elite, titanium rails, 138mm width\nSeatpost\tBontrager carbon seatmast cap, 20mm offset\nHandlebar\tBontrager Elite Aero VR-CF, alloy, 31.8mm, internal cable routing, 40cm width\nGrips\tBontrager Supertack Perf tape\nStem\tBontrager Elite, 31.8mm, Blendr-compatible, 7 degree, 80mm length\nBrake Shimano Ultegra hydraulic disc brake\n\nWeight\nWeight\t56 - 8.91 kg / 19.63 lbs (with tubes)\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\n| Size | Rider height |\n|------|-------------|\n| 50 | 162-166cm |\n| 52 | 165-170cm |\n| 54 | 168-174cm |\n| 56 | 174-180cm |\n| 58 | 179-184cm |\n| 60 | 184-189cm |\n| 62 | 189-196cm |\n\n## Geometry\n| Frame size | 50cm | 52cm | 54cm | 56cm | 58cm | 60cm | 62cm |\n|------------|-------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c | 700c |\n| A - Seat tube | 443mm | 460mm | 478mm | 500mm | 520mm | 540mm | 560mm |\n| B - Seat tube angle | 74.1° | 73.9° | 73.7° | 73.4° | 73.2° | 73.0° | 72.8° |\n| C - Head tube length | 100mm | 110mm | 130mm | 150mm | 170mm | 190mm | 210mm |\n| D - Head angle | 71.4° | 72.0° | 72.5° | 73.0° | 73.3° | 73.6° | 73.8° |\n| E - Effective top tube | 522mm | 535mm | 547mm | 562mm | 577mm | 593mm | 610mm |\n| F - Bottom bracket height | 268mm | 268mm | 268mm | 268mm | 268mm | 268mm | 268mm |\n| G - Bottom bracket drop | 69mm | 69mm | 69mm | 69mm | 69mm | 69mm | 69mm |\n| H - Chainstay length | 410mm | 410mm | 410mm | 410mm | 410mm | 410mm | 410mm |\n| I - Offset | 50mm | 50mm | 50mm | 50mm | 50mm | 50mm | 50mm |\n| J - Trail | 65mm | 62mm | 59mm | 56mm | 55mm | 53mm | 52mm |\n| K - Wheelbase | 983mm | 983mm | 990mm | 1005mm | 1019mm | 1036mm | 1055mm |\n| L - Standover | 741mm | 765mm | 787mm | 806mm | 825mm | 847mm | 869mm |", - "price": 799.99, - "tags": [ - "bicycle", - "mountain bike" - ] - }, - { - "name": "Celerity X5", - "shortDescription": "Celerity X5 is a versatile and reliable road bike that is designed for experienced and amateur riders alike. It's designed to provide smooth and comfortable rides over long distances. With an ultra-lightweight and responsive carbon fiber frame, Shimano 105 groupset, hydraulic disc brakes, and 28mm wide tires, this bike ensures efficient power transfer, precise handling, and superior stopping power.", - "description": "## Overview\n\nIt's right for you if... \nYou are looking for a high-performance road bike that offers a perfect balance of speed, comfort, and control. You enjoy long-distance rides and need a bike that is designed to handle various road conditions with ease. You also appreciate the latest technology and reliable components that make your riding experience more enjoyable.\n\nThe tech you get \nCelerity X5 is equipped with a full carbon fiber frame that ensures maximum strength and durability while keeping the weight down. It features a Shimano 105 groupset with 11-speed gearing for precise and efficient shifting. Hydraulic disc brakes offer superior stopping power, and 28mm wide tires provide comfort and stability on various road surfaces. Internal cable routing enhances the bike's sleek appearance.\n\nThe final word \nIf you are looking for a high-performance road bike that offers comfort, speed, and control, Celerity X5 is the perfect choice. With its lightweight carbon fiber frame, reliable components, and advanced technology, this bike is designed to help you enjoy long-distance rides with ease.\n\n## Features \n\nLightweight and responsive \nCelerity X5 comes with a full carbon fiber frame that is not only lightweight but also responsive, providing excellent handling and control.\n\nHydraulic disc brakes \nThis bike is equipped with hydraulic disc brakes that provide superior stopping power in all weather conditions, ensuring your safety and confidence on the road.\n\nComfortable rides \nThe 28mm wide tires and carbon seat post provide ample cushioning, ensuring a smooth and comfortable ride over long distances.\n\nSleek appearance \nThe bike's internal cable routing enhances its sleek appearance while also protecting the cables from the elements, ensuring smooth shifting for longer periods.\n\n## Specifications \n\nFrameset \nFrame\tCelerity X5 Full Carbon Fiber Frame, Internal Cable Routing, Tapered Headtube, Press Fit Bottom Bracket, 12x142mm Thru-Axle \nFork\tCelerity X5 Full Carbon Fiber Fork, Internal Brake Routing, 12x100mm Thru-Axle \n\nWheels \nWheelset\tAlexRims CXD7 Wheelset \nTire\tSchwalbe Durano Plus 700x28mm \nInner Tubes\tSchwalbe SV15 700x18-28mm \nSkewers\tCelerity X5 Thru-Axle Skewers \n\nDrivetrain \nShifter\tShimano 105 R7025 Hydraulic Disc Shifters \nFront Derailleur\tShimano 105 R7000 \nRear Derailleur\tShimano 105 R7000 \nCrankset\tShimano 105 R7000 50-34T \nBottom Bracket\tShimano BB72-41B \nCassette\tShimano 105 R7000 11-30T \nChain\tShimano HG601 11-Speed Chain \n\nComponents \nSaddle\tSelle Royal Asphalt Saddle \nSeatpost\tCelerity X5 Carbon Seatpost \nHandlebar\tCelerity X5 Compact Handlebar \nStem\tCelerity X5 Aluminum Stem \nHeadset\tFSA Orbit IS-2 \n\nBrakes \nBrakes\tShimano 105 R7025 Hydraulic Disc Brakes \nRotors\tShimano SM-RT70 160mm Rotors \n\nAccessories \nPedals\tCelerity X5 Road Pedals \n\nWeight \nWeight\t8.2 kg / 18.1 lbs \nWeight Limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 120 kg (265 lbs).\n\n## Sizing \n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| 49 | 155 - 162 cm 5'1\" - 5'4\" | 71 - 76 cm 28\" - 30\" |\n| 52 | 162 - 170 cm 5'4\" - 5'7\" | 74 - 79 cm 29\" - 31\" |\n| 54 | 170 - 178 cm 5'7\" - 5'10\" | 77 - 83 cm 30\" - 32\" |\n| 56 | 178 - 185 cm 5'10\" - 6'1\" | 82 - 88 cm 32\" - 34\" |\n| 58 | 185 - 193 cm 6'1\" - 6'4\" | 86 - 92 cm 34\" - 36\" |\n| 61 | 193 - 200 cm 6'4\" - 6'7\" | 90 - 95 cm 35\" - 37\" |\n\n## Geometry \n| Frame size number | 49 cm | 52 cm | 54 cm | 56 cm | 58 cm | 61 cm |\n|---------------------------------------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 47.5 | 50.0 | 52.0 | 54.0 | 56.0 | 58.5 |\n| B — Seat tube angle | 75.0° | 74.5° | 74.0° | 73.5° | 73.0° | 72.5° |\n| C — Head tube length | 12.0 | 14.5 | 16.5 | 18.5 | 20.5 | 23.5 |\n| D — Head angle | 70.0° | 71.0° | 71.5° | 72.0° | 72.5° | 72.5° |\n| E — Effective top tube | 52.5 | 53.5 | 54.5 | 56.0 | 57.5 | 59.5 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 |\n| K — Wheelbase | 98.4 | 98.9 | 99.8 | 100.8 | 101.7 | 103.6 |\n| L — Standover | 72.0 | 74.0 | 76.0 | 78.0 | 80.0 | 82.0 |\n| M — Frame reach | 36.2 | 36.8 | 37.3 | 38.1 | 38.6 | 39.4 |\n| N — Frame stack | 52.0 | 54.3 | 56.2 | 58.1 | 59.8 | 62.4 |\n| Saddle rail height min | 67.0 | 69.5 | 71.5 | 74.0 | 76.0 | 78.0 |\n| Saddle rail height max | 75.0 | 77.5 | 79.5 | 82.0 | 84.0 | 86.0 |", - "price": 399.99, - "tags": [ - "bicycle", - "city bike" - ] - }, - { - "name": "Velocity V8", - "shortDescription": "Velocity V8 is a high-performance road bike that is designed to deliver speed, agility, and control on the road. With its lightweight aluminum frame, carbon fiber fork, Shimano Tiagra groupset, and hydraulic disc brakes, this bike is perfect for experienced riders who are looking for a fast and responsive bike that can handle various road conditions.", - "description": "## Overview\n\nIt's right for you if... \nYou are an experienced rider who is looking for a high-performance road bike that is lightweight, agile, and responsive. You want a bike that can handle long-distance rides, steep climbs, and fast descents with ease. You also appreciate the latest technology and reliable components that make your riding experience more enjoyable.\n\nThe tech you get \nVelocity V8 features a lightweight aluminum frame with a carbon fiber fork that ensures a comfortable ride without sacrificing stiffness and power transfer. It comes with a Shimano Tiagra groupset with 10-speed gearing for precise and efficient shifting. Hydraulic disc brakes offer superior stopping power in all weather conditions, while 28mm wide tires provide comfort and stability on various road surfaces. Internal cable routing enhances the bike's sleek appearance.\n\nThe final word \nIf you are looking for a high-performance road bike that is lightweight, fast, and responsive, Velocity V8 is the perfect choice. With its lightweight aluminum frame, reliable components, and advanced technology, this bike is designed to help you enjoy fast and comfortable rides on the road.\n\n## Features \n\nLightweight and responsive \nVelocity V8 comes with a lightweight aluminum frame that is not only lightweight but also responsive, providing excellent handling and control.\n\nHydraulic disc brakes \nThis bike is equipped with hydraulic disc brakes that provide superior stopping power in all weather conditions, ensuring your safety and confidence on the road.\n\nComfortable rides \nThe 28mm wide tires and carbon fork provide ample cushioning, ensuring a smooth and comfortable ride over long distances.\n\nSleek appearance \nThe bike's internal cable routing enhances its sleek appearance while also protecting the cables from the elements, ensuring smooth shifting for longer periods.\n\n## Specifications \n\nFrameset \nFrame\tVelocity V8 Aluminum Frame, Internal Cable Routing, Tapered Headtube, Press Fit Bottom Bracket, 12x142mm Thru-Axle \nFork\tVelocity V8 Carbon Fiber Fork, Internal Brake Routing, 12x100mm Thru-Axle \n\nWheels \nWheelset\tAlexRims CXD7 Wheelset \nTire\tSchwalbe Durano Plus 700x28mm \nInner Tubes\tSchwalbe SV15 700x18-28mm \nSkewers\tVelocity V8 Thru-Axle Skewers \n\nDrivetrain \nShifter\tShimano Tiagra Hydraulic Disc Shifters \nFront Derailleur\tShimano Tiagra \nRear Derailleur\tShimano Tiagra \nCrankset\tShimano Tiagra 50-34T \nBottom Bracket\tShimano BB-RS500-PB \nCassette\tShimano Tiagra 11-32T \nChain\tShimano HG54 10-Speed Chain \n\nComponents \nSaddle\tVelocity V8 Saddle \nSeatpost\tVelocity V8 Aluminum Seatpost \nHandlebar\tVelocity V8 Compact Handlebar \nStem\tVelocity V8 Aluminum Stem \nHeadset\tFSA Orbit IS-2 \n\nBrakes \nBrakes\tShimano Tiagra Hydraulic Disc Brakes \nRotors\tShimano SM-RT64 160mm Rotors \n\nAccessories \nPedals\tVelocity V8 Road Pedals \n\nWeight \nWeight\t9.4 kg / 20.7 lbs \nWeight Limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 120 kg (265 lbs).\n\n## Sizing \n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| 49 | 155 - 162 cm 5'1\" - 5'4\" | 71 - 76 cm 28\" - 30\" |\n| 52 | 162 - 170 cm 5'4\" - 5'7\" | 74 - 79 cm 29\" - 31\" |\n| 54 | 170 - 178 cm 5'7\" - 5'10\" | 77 - 83 cm 30\" - 32\" |\n| 56 | 178 - 185 cm 5'10\" - 6'1\" | 82 - 88 cm 32\" - 34\" |\n| 58 | 185 - 193 cm 6'1\" - 6'4\" | 86 - 92 cm 34\" - 36\" |\n| 61 | 193 - 200 cm 6'4\" - 6'7\" | 90 - 95 cm 35\" - 37\" |\n\n## Geometry \n| Frame size number | 49 cm | 52 cm | 54 cm | 56 cm | 58 cm | 61 cm |\n|---------------------------------------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 47.5 | 50.0 | 52.0 | 54.0 | 56.0 | 58.5 |\n| B — Seat tube angle | 75.0° | 74.5° | 74.0° | 73.5° | 73.0° | 72.5° |\n| C — Head tube length | 12.0 | 14.5 | 16.5 | 18.5 | 20.5 | 23.5 |\n| D — Head angle | 70.0° | 71.0° | 71.5° | 72.0° | 72.5° | 72.5° |\n| E — Effective top tube | 52.5 | 53.5 | 54.5 | 56.0 | 57.5 | 59.5 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 |\n| K — Wheelbase | 98.4 | 98.9 | 99.8 | 100.8 | 101.7 | 103.6 |\n| L — Standover | 72.0 | 74.0 | 76.0 | 78.0 | 80.0 | 82.0 |\n| M — Frame reach | 36.2 | 36.8 | 37.3 | 38.1 | 38.6 | 39.4 |\n| N — Frame stack | 52.0 | 54.3 | 56.2 | 58.1 | 59.8 | 62.4 |\n| Saddle rail height min | 67.0 | 69.5 | 71.5 | 74.0 | 76.0 | 78.0 |\n| Saddle rail height max | 75.0 | 77.5 | 79.5 | 82.0 | 84.0 | 86.0 |", - "price": 1899.99, - "tags": [ - "bicycle", - "electric bike" - ] - }, - { - "name": "VeloCore X9 eMTB", - "shortDescription": "The VeloCore X9 eMTB is a light, agile and versatile electric mountain bike designed for adventure and performance. Its purpose-built frame and premium components offer an exhilarating ride experience on both technical terrain and smooth singletrack.", - "description": "## Overview\nIt's right for you if...\nYou love exploring new trails and testing your limits on challenging terrain. You want an electric mountain bike that offers power when you need it, without sacrificing performance or agility. You're looking for a high-quality bike with top-notch components and a sleek design.\n\nThe tech you get\nA lightweight, full carbon frame with custom geometry, a 140mm RockShox Pike Ultimate fork with Charger 2.1 damper, and a Fox Float DPS Performance shock. A Shimano STEPS E8000 motor and 504Wh battery that provide up to 62 miles of range and 20 mph assistance. A Shimano XT 12-speed drivetrain, Shimano SLX brakes, and DT Swiss wheels.\n\nThe final word\nThe VeloCore X9 eMTB delivers power and agility in equal measure. It's a versatile and capable electric mountain bike that can handle any trail with ease. With premium components, a custom carbon frame, and a sleek design, this bike is built for adventure.\n\n## Features\nAgile and responsive\n\nThe VeloCore X9 eMTB is designed to be nimble and responsive on the trail. Its custom carbon frame offers a perfect balance of stiffness and compliance, while the suspension system provides smooth and stable performance on technical terrain.\n\nPowerful and efficient\n\nThe Shimano STEPS E8000 motor and 504Wh battery provide up to 62 miles of range and 20 mph assistance. The motor delivers smooth and powerful performance, while the battery offers reliable and consistent power for long rides.\n\nCustomizable ride experience\n\nThe VeloCore X9 eMTB comes with an intuitive and customizable Shimano STEPS display that allows you to adjust the level of assistance, monitor your speed and battery life, and customize your ride experience to suit your needs.\n\nPremium components\n\nThe VeloCore X9 eMTB is equipped with high-end components, including a Shimano XT 12-speed drivetrain, Shimano SLX brakes, and DT Swiss wheels. These components offer reliable and precise performance, allowing you to push your limits with confidence.\n\n## Specs\nFrameset\nFrame\tVeloCore carbon fiber frame, Boost, tapered head tube, internal cable routing, 140mm travel\nFork\tRockShox Pike Ultimate, Charger 2.1 damper, DebonAir spring, 15x110mm Boost Maxle Ultimate, 46mm offset, 140mm travel\nShock\tFox Float DPS Performance, EVOL, 3-position adjust, Kashima Coat, 210x50mm\n\nWheels\nWheel front\tDT Swiss XM1700 Spline, 30mm internal width, 15x110mm Boost axle\nWheel rear\tDT Swiss XM1700 Spline, 30mm internal width, Shimano Microspline driver, 12x148mm Boost axle\nTire front\tMaxxis Minion DHF, 29x2.5\", EXO+ casing, tubeless ready\nTire rear\tMaxxis Minion DHR II, 29x2.4\", EXO+ casing, tubeless ready\n\nDrivetrain\nShifter\tShimano XT M8100, 12-speed\nRear derailleur\tShimano XT M8100, Shadow Plus, long cage, 51T max cog\nCrankset\tShimano STEPS E8000, 165mm length, 34T chainring\nCassette\tShimano XT M8100, 10-51T, 12-speed\nChain\tShimano CN-M8100, 12-speed\nPedals\tNot included\n\nComponents\nSaddle\tBontrager Arvada, hollow chromoly rails\nSeatpost\tDrop Line, internal routing, 31.6mm (15.5: 100mm, 17.5 & 18.5: 125mm, 19.5 & 21.5: 150mm)\nHandlebar\tBontrager Line Pro, ADV Carbon, 35mm, 27.5mm rise, 780mm width\nStem\tBontrager Line Pro, 35mm, Knock Block, 0 degree, 50mm length\nGrips\tBontrager XR Trail Elite, alloy lock-on\nHeadset\tIntegrated, sealed cartridge bearing, 1-1/8\" top, 1.5\" bottom\nBrakeset\tShimano SLX M7120, 4-piston hydraulic disc\n\nAccessories\nBattery\tShimano STEPS BT-E8010, 504Wh\nCharger\tShimano STEPS EC-E8004, 4A\nController\tShimano STEPS E8000 display\nBike weight\tM - 22.5 kg / 49.6 lbs (with tubes)\n\n## Sizing & fit\n\n| Size | Rider Height |\n|:----:|:------------------------:|\n| S | 162 - 170 cm 5'4\" - 5'7\" |\n| M | 170 - 178 cm 5'7\" - 5'10\"|\n| L | 178 - 186 cm 5'10\" - 6'1\"|\n| XL | 186 - 196 cm 6'1\" - 6'5\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\n| Frame size | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| A — Seat tube | 40.6 | 43.2 | 47.0 | 51.0 |\n| B — Seat tube angle | 75.0° | 75.0° | 75.0° | 75.0° |\n| C — Head tube length | 9.6 | 10.6 | 11.6 | 12.6 |\n| D — Head angle | 66.5° | 66.5° | 66.5° | 66.5° |\n| E — Effective top tube | 60.4 | 62.6 | 64.8 | 66.9 |\n| F — Bottom bracket height | 33.2 | 33.2 | 33.2 | 33.2 |\n| G — Bottom bracket drop | 3.0 | 3.0 | 3.0 | 3.0 |\n| H — Chainstay length | 45.5 | 45.5 | 45.5 | 45.5 |\n| I — Offset | 4.6 | 4.6 | 4.6 | 4.6 |\n| J — Trail | 11.9 | 11.9 | 11.9 | 11.9 |\n| K — Wheelbase | 117.0 | 119.3 | 121.6 | 123.9 |\n| L — Standover | 75.9 | 75.9 | 78.6 | 78.6 |\n| M — Frame reach | 43.6 | 45.6 | 47.6 | 49.6 |\n| N — Frame stack | 60.5 | 61.5 | 62.4 | 63.4 |", - "price": 1299.99, - "tags": [ - "bicycle", - "touring bike" - ] - }, - { - "name": "Zephyr 8.8 GX Eagle AXS Gen 3", - "shortDescription": "Zephyr 8.8 GX Eagle AXS is a light and nimble full-suspension mountain bike. It's designed to handle technical terrain with ease and has a smooth and efficient ride feel. The sleek and powerful Bosch Performance Line CX motor and removable Powertube battery provide a boost to your pedaling and give you long-lasting riding time. The bike also features high-end components and advanced technology for an ultimate mountain biking experience.", - "description": "## Overview\nIt's right for you if...\nYou're an avid mountain biker looking for a high-performance e-MTB that can tackle challenging trails. You want a bike with a powerful motor, efficient suspension, and advanced technology to enhance your riding experience. You also need a bike that's reliable and durable for long-lasting use.\n\nThe tech you get\nA lightweight, full carbon frame with 150mm of rear travel and a 160mm RockShox Pike Ultimate fork with Charger 2.1 RCT3 damper, remote lockout, and DebonAir spring. A Bosch Performance Line CX motor and removable Powertube 625Wh battery that can assist up to 20mph when it's on and gives zero drag when it's off, plus an easy-to-use handlebar-mounted Bosch Purion controller. A SRAM GX Eagle AXS wireless electronic drivetrain, a RockShox Reverb Stealth dropper, and DT Swiss HX1501 Spline One wheels.\n\nThe final word\nZephyr 8.8 GX Eagle AXS is a high-performance e-MTB that's designed to handle technical terrain with ease. With a powerful Bosch motor and long-lasting battery, you can conquer challenging climbs and enjoy long rides. The bike also features high-end components and advanced technology for an ultimate mountain biking experience.\n\n## Features\nPowerful motor\n\nThe Bosch Performance Line CX motor provides a boost to your pedaling and can assist up to 20mph. It has four power modes and a walk-assist function for easy navigation on steep climbs. The motor is also reliable and durable for long-lasting use.\n\nEfficient suspension\n\nZephyr 8.8 has a 150mm of rear travel and a 160mm RockShox Pike Ultimate fork with Charger 2.1 RCT3 damper, remote lockout, and DebonAir spring. The suspension is efficient and responsive, allowing you to handle technical terrain with ease.\n\nRemovable battery\n\nThe Powertube 625Wh battery is removable for easy charging and storage. It provides long-lasting riding time and can be replaced with a spare battery for even longer rides. The battery is also durable and weather-resistant for all-season riding.\n\nAdvanced technology\n\nZephyr 8.8 is equipped with advanced technology, including a Bosch Purion controller for easy motor control, a SRAM GX Eagle AXS wireless electronic drivetrain for precise shifting, and a RockShox Reverb Stealth dropper for adjustable saddle height. The bike also has DT Swiss HX1501 Spline One wheels for reliable performance on any terrain.\n\nCarbon frame\n\nThe full carbon frame is lightweight and durable, providing a smooth and efficient ride. It's also designed with a tapered head tube, internal cable routing, and Boost148 spacing for enhanced stiffness and responsiveness.\n\n## Specs\nFrameset\nFrame\tCarbon main frame & stays, tapered head tube, internal routing, Boost148, 150mm travel\nFork\tRockShox Pike Ultimate, Charger 2.1 RCT3 damper, DebonAir spring, remote lockout, tapered steerer, Boost110, 15mm Maxle Stealth, 160mm travel\nShock\tRockShox Deluxe RT3, DebonAir spring, 205mm x 57.5mm\nMax compatible fork travel\t170mm\n\nWheels\nWheel front\tDT Swiss HX1501 Spline One, Centerlock, 30mm inner width, 110x15mm Boost\nWheel rear\tDT Swiss HX1501 Spline One, Centerlock, 30mm inner width, SRAM XD driver, 148x12mm Boost\nTire\tBontrager XR4 Team Issue, Tubeless Ready, Inner Strength sidewall, aramid bead, 120tpi, 29x2.40''\nMax tire size\t29x2.60\"\n\nDrivetrain\nShifter\tSRAM GX Eagle AXS, wireless, 12 speed\nRear derailleur\tSRAM GX Eagle AXS\nCrank\tBosch Gen 4, 32T\nChainring\tSRAM X-Sync 2, 32T, direct-mount\nCassette\tSRAM PG-1275 Eagle, 10-52, 12 speed\nChain\tSRAM GX Eagle, 12 speed\n\nComponents\nSaddle\tBontrager Arvada, hollow titanium rails, 138mm width\nSeatpost\tRockShox Reverb Stealth, 31.6mm, internal routing, 150mm (S), 170mm (M/L), 200mm (XL)\nHandlebar\tBontrager Line Pro, ADV Carbon, 35mm, 27.5mm rise, 780mm width\nGrips\tBontrager XR Trail Elite, alloy lock-on\nStem\tBontrager Line Pro, Knock Block, 35mm, 0 degree, 50mm length\nHeadset\tIntegrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\nBrake\tSRAM Code RSC hydraulic disc, 200mm (front), 180mm (rear)\nBrake rotor\tSRAM CenterLine, centerlock, round edge, 200mm (front), 180mm (rear)\n\nAccessories\nE-bike system\tBosch Performance Line CX\nBattery\tBosch Powertube 625Wh\nCharger\tBosch 4A compact charger\nController\tBosch Purion\nTool\tBontrager multi-tool, integrated storage bag\n\nWeight\nWeight\tM - 24.08 kg / 53.07 lbs (with TLR sealant, no tubes)\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n\n## Sizing & fit\n\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| S | 153 - 162 cm 5'0\" - 5'4\" | 67 - 74 cm 26\" - 29\" |\n| M | 161 - 172 cm 5'3\" - 5'8\" | 74 - 79 cm 29\" - 31\" |\n| L | 171 - 180 cm 5'7\" - 5'11\" | 79 - 84 cm 31\" - 33\" |\n| XL | 179 - 188 cm 5'10\" - 6'2\" | 84 - 89 cm 33\" - 35\" |\n\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Actual frame size | 15.5 | 17.5 | 19.5 | 21.5 |\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\n| A — Seat tube | 39.4 | 41.9 | 44.5 | 47.6 |\n| B — Seat tube angle | 76.1° | 76.1° | 76.1° | 76.1° |\n| C — Head tube length | 9.6 | 10.5 | 11.5 | 12.5 |\n| D — Head angle | 65.5° | 65.5° | 65.5° | 65.5° |\n| E — Effective top tube | 58.6 | 61.3 | 64.0 | 66.7 |\n| F — Bottom bracket height | 34.0 | 34.0 | 34.0 | 34.0 |\n| G — Bottom bracket drop | 1.0 | 1.0 | 1.0 | 1.0 |\n| H — Chainstay length | 45.0 | 45.0 | 45.0 | 45.0 |\n| I — Offset | 4.6 | 4.6 | 4.6 | 4.6 |\n| J — Trail | 10.5 | 10.5 | 10.5 | 10.5 |\n| K — Wheelbase | 119.5 | 122.3 | 125.0 | 127.8 |\n| L — Standover | 72.7 | 74.7 | 77.6 | 81.0 |\n|", - "price": 1499.99, - "tags": [ - "bicycle", - "electric bike", - "city bike" - ] - }, - { - "name": "Velo 99 XR1 AXS", - "shortDescription": "Velo 99 XR1 AXS is a next-generation bike designed for fast-paced adventure seekers and speed enthusiasts. Built for high-performance racing, the bike boasts state-of-the-art technology and premium components. It is the ultimate bike for riders who want to push their limits and get their adrenaline pumping.", - "description": "## Overview\nIt's right for you if...\nYou are a passionate cyclist looking for a bike that can keep up with your speed, agility, and endurance. You are an adventurer who loves to explore new terrains and challenge yourself on the toughest courses. You want a bike that is lightweight, durable, and packed with the latest technology.\n\nThe tech you get\nA lightweight, full carbon frame with advanced aerodynamics and integrated cable routing for a clean look. A high-performance SRAM XX1 Eagle AXS wireless electronic drivetrain, featuring a 12-speed cassette and a 32T chainring. A RockShox SID Ultimate fork with a remote lockout, 120mm travel, and Charger Race Day damper. A high-end SRAM G2 Ultimate hydraulic disc brake with carbon levers. A FOX Transfer SL dropper post for quick and easy height adjustments. DT Swiss XRC 1501 carbon wheels for superior speed and handling.\n\nThe final word\nVelo 99 XR1 AXS is a premium racing bike that can help you achieve your goals and reach new heights. It is designed for speed, agility, and performance, and it is packed with the latest technology and premium components. If you are a serious cyclist who wants the best, this is the bike for you.\n\n## Features\nAerodynamic design\n\nThe Velo 99 XR1 AXS features a state-of-the-art frame design that reduces drag and improves speed. It has an aerodynamic seatpost, integrated cable routing, and a sleek, streamlined look that sets it apart from other bikes.\n\nWireless electronic drivetrain\n\nThe SRAM XX1 Eagle AXS drivetrain features a wireless electronic system that provides precise, instant shifting and unmatched efficiency. It eliminates the need for cables and makes the bike lighter and faster.\n\nHigh-performance suspension\n\nThe RockShox SID Ultimate fork and Charger Race Day damper provide 120mm of smooth, responsive suspension that can handle any terrain. The fork also has a remote lockout for quick adjustments on the fly.\n\nSuperior braking power\n\nThe SRAM G2 Ultimate hydraulic disc brake system delivers unmatched stopping power and control. It has carbon levers for a lightweight, ergonomic design and precision control.\n\nCarbon wheels\n\nThe DT Swiss XRC 1501 carbon wheels are ultra-lightweight, yet incredibly strong and durable. They provide superior speed and handling, making the bike more agile and responsive.\n\n## Specs\nFrameset\nFrame\tFull carbon frame, integrated cable routing, aerodynamic design, Boost148\nFork\tRockShox SID Ultimate, Charger Race Day damper, remote lockout, tapered steerer, Boost110, 15mm Maxle Stealth, 120mm travel\n\nWheels\nWheel front\tDT Swiss XRC 1501 carbon wheel, Boost110, 15mm thru axle\nWheel rear\tDT Swiss XRC 1501 carbon wheel, SRAM XD driver, Boost148, 12mm thru axle\nTire\tSchwalbe Racing Ray, Performance Line, Addix, 29x2.25\"\nTire part\tSchwalbe Doc Blue Professional, 500ml\nMax tire size\t29x2.3\"\n\nDrivetrain\nShifter\tSRAM Eagle AXS, wireless, 12-speed\nRear derailleur\tSRAM XX1 Eagle AXS\nCrank\tSRAM XX1 Eagle, 32T, carbon\nChainring\tSRAM X-SYNC, 32T, alloy\nCassette\tSRAM Eagle XG-1299, 10-52, 12-speed\nChain\tSRAM XX1 Eagle, 12-speed\nMax chainring size\t1x: 32T\n\nComponents\nSaddle\tBontrager Montrose Elite, carbon rails, 138mm width\nSeatpost\tFOX Transfer SL, 125mm travel, internal routing, 31.6mm\nHandlebar\tBontrager Kovee Pro, ADV Carbon, 35mm, 5mm rise, 720mm width\nGrips\tBontrager XR Endurance Elite\nStem\tBontrager Kovee Pro, 35mm, Blendr compatible, 7 degree, 60mm length\nHeadset\tIntegrated, cartridge bearing, 1-1/8\" top, 1.5\" bottom\nBrake\tSRAM G2 Ultimate hydraulic disc, carbon levers, 180mm rotors\n\nAccessories\nBike computer\tBontrager Trip 300\nTool\tBontrager Flatline Pro pedal wrench, T25 Torx\n\n\n## Sizing & fit\n\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| S | 158 - 168 cm 5'2\" - 5'6\" | 74 - 78 cm 29\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 78 - 82 cm 31\" - 32\" |\n| L | 173 - 183 cm 5'8\" - 6'0\" | 82 - 86 cm 32\" - 34\" |\n| XL | 180 - 193 cm 5'11\" - 6'4\" | 86 - 90 cm 34\" - 35\" |\n\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Actual frame size | 15.5 | 17.5 | 19.5 | 21.5 |\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\n| A — Seat tube | 39.9 | 43.0 | 47.0 | 51.0 |\n| B — Seat tube angle | 74.5° | 74.5° | 74.5° | 74.5° |\n| C — Head tube length | 9.0 | 10.0 | 11.0 | 12.0 |\n| D — Head angle | 68.0° | 68.0° | 68.0° | 68.0° |\n| E — Effective top tube | 57.8 | 59.7 | 61.6 | 63.6 |\n| F — Bottom bracket height | 33.0 | 33.0 | 33.0 | 33.0 |\n| G — Bottom bracket drop | 5.0 | 5.0 | 5.0 | 5.0 |\n| H — Chainstay length | 43.0 | 43.0 | 43.0 | 43.0 |\n| I — Offset | 4.2 | 4.2 | 4.2 | 4.2 |\n| J — Trail | 9.7 | 9.7 | 9.7 | 9.7 |\n| K — Wheelbase | 112.5 | 114.5 | 116.5 | 118.6 |\n| L — Standover | 75.9 | 77.8 | 81.5 | 84.2 |\n| M — Frame reach | 41.6 | 43.4 | 45.2 | 47.1 |\n| N — Frame stack | 58.2 | 58.9 | 59.3 | 59.9 |", - "price": 1099.99, - "tags": [ - "bicycle", - "mountain bike" - ] - }, - { - "name": "AURORA 11S E-MTB", - "shortDescription": "The AURORA 11S is a powerful and stylish electric mountain bike designed to take you on thrilling off-road adventures. With its sturdy frame and premium components, this bike is built to handle any terrain. It features a high-performance motor, long-lasting battery, and advanced suspension system that guarantee a smooth and comfortable ride.", - "description": "## Overview\nIt's right for you if...\nYou want a top-of-the-line e-MTB that is both powerful and stylish. You also want a bike that can handle any terrain, from steep climbs to rocky descents. With its advanced features and premium components, the AURORA 11S is designed for serious off-road riders who demand the best.\n\nThe tech you get\nA sturdy aluminum frame with advanced suspension system that provides 120mm of travel. A 750W brushless motor that delivers up to 28mph, and a 48V/14Ah lithium-ion battery that provides up to 60 miles of range on a single charge. An advanced 11-speed Shimano drivetrain with hydraulic disc brakes for precise shifting and reliable stopping power. \n\nThe final word\nThe AURORA 11S is a top-of-the-line e-MTB that delivers exceptional performance and style. Whether you're tackling steep climbs or hitting rocky descents, this bike is built to handle any terrain with ease. With its advanced features and premium components, the AURORA 11S is the perfect choice for serious off-road riders who demand the best.\n\n## Features\nPowerful and efficient\n\nThe AURORA 11S is equipped with a high-performance 750W brushless motor that delivers up to 28mph. The motor is powered by a long-lasting 48V/14Ah lithium-ion battery that provides up to 60 miles of range on a single charge.\n\nAdvanced suspension system\n\nThe bike's advanced suspension system provides 120mm of travel, ensuring a smooth and comfortable ride on any terrain. The front suspension is a Suntour XCR32 Air fork, while the rear suspension is a KS-281 hydraulic shock absorber.\n\nPremium components\n\nThe AURORA 11S features an advanced 11-speed Shimano drivetrain with hydraulic disc brakes. The bike is also equipped with a Tektro HD-E725 hydraulic disc brake system that provides reliable stopping power.\n\nSleek and stylish design\n\nWith its sleek and stylish design, the AURORA 11S is sure to turn heads on the trail. The bike's sturdy aluminum frame is available in a range of colors, including black, blue, and red.\n\n## Specs\nFrameset\nFrame Material: Aluminum\nFrame Size: S, M, L\nFork: Suntour XCR32 Air, 120mm Travel\nShock Absorber: KS-281 Hydraulic Shock Absorber\n\nWheels\nWheel Size: 27.5 inches\nTires: Kenda K1151 Nevegal, 27.5x2.35\nRims: Alloy Double Wall\nSpokes: 32H, Stainless Steel\n\nDrivetrain\nShifters: Shimano SL-M7000\nRear Derailleur: Shimano RD-M8000\nCrankset: Prowheel 42T, Alloy Crank Arm\nCassette: Shimano CS-M7000, 11-42T\nChain: KMC X11EPT\n\nBrakes\nBrake System: Tektro HD-E725 Hydraulic Disc Brake\nBrake Rotors: 180mm Front, 160mm Rear\n\nE-bike system\nMotor: 750W Brushless\nBattery: 48V/14Ah Lithium-Ion\nCharger: 48V/3A Smart Charger\nController: Intelligent Sinusoidal Wave\n\nWeight\nWeight: 59.5 lbs\n\n## Sizing & fit\n| Size | Rider Height | Standover Height |\n|------|-------------|-----------------|\n| S | 5'2\"-5'6\" | 28.5\" |\n| M | 5'7\"-6'0\" | 29.5\" |\n| L | 6'0\"-6'4\" | 30.5\" |\n\n## Geometry\nAll measurements provided in cm.\nSizing table\n| Frame size letter | S | M | L |\n|-------------------|-----|-----|-----|\n| Wheel Size | 27.5\"| 27.5\"| 27.5\"|\n| Seat tube length | 44.5| 48.5| 52.5|\n| Head tube angle | 68° | 68° | 68° |\n| Seat tube angle | 74.5°| 74.5°| 74.5°|\n| Effective top tube | 57.5| 59.5| 61.5|\n| Head tube length | 12.0| 12.0| 13.0|\n| Chainstay length | 45.5| 45.5| 45.5|\n| Bottom bracket height | 30.0| 30.0| 30.0|\n| Wheelbase | 115.0|116.5|118.5|", - "price": 1999.99, - "tags": [ - "bicycle", - "road bike" - ] - }, - { - "name": "VeloTech V9.5 AXS Gen 3", - "shortDescription": "VeloTech V9.5 AXS is a sleek and fast carbon bike that combines high-end tech with a comfortable ride. It's designed to provide the ultimate experience for the most serious riders. The bike comes with a lightweight and powerful motor that can be activated when needed, and you get a spec filled with premium parts.", - "description": "## Overview\nIt's right for you if...\nYou want a bike that is fast, efficient, and delivers an adrenaline-filled experience. You are looking for a bike that is built with cutting-edge technology, and you want a ride that is both comfortable and exciting.\n\nThe tech you get\nA lightweight and durable full carbon frame with a fork that has 100mm of travel. The bike comes with a powerful motor that can deliver up to 20 mph of assistance. The drivetrain is a wireless electronic system that is precise and reliable. The bike is also equipped with hydraulic disc brakes, tubeless-ready wheels, and comfortable grips.\n\nThe final word\nThe VeloTech V9.5 AXS is a high-end bike that delivers an incredible experience for serious riders. It combines the latest technology with a comfortable ride, making it perfect for long rides, tough climbs, and fast descents.\n\n## Features\nFast and efficient\nThe VeloTech V9.5 AXS comes with a powerful motor that can provide up to 20 mph of assistance. The motor is lightweight and efficient, providing a boost when you need it without adding bulk. The bike's battery is removable, allowing you to ride without assistance when you don't need it.\n\nSmart software for the trail\nThe VeloTech V9.5 AXS is equipped with intelligent software that delivers a smooth and responsive ride. The software allows the motor to respond immediately as you start to pedal, delivering more power over a wider cadence range. You can also customize your user settings to suit your preferences.\n\nComfortable ride\nThe VeloTech V9.5 AXS is designed to provide a comfortable ride, even on long rides. The bike's fork has 100mm of travel, providing ample cushioning for rough terrain. The bike's grips are also designed to provide a comfortable and secure grip, even on the most challenging rides.\n\n## Specs\nFrameset\nFrame\tCarbon fiber frame with internal cable routing and Boost148\nFork\t100mm of travel with remote lockout\nShock\tN/A\n\nWheels\nWheel front\tCarbon fiber tubeless-ready wheel\nWheel rear\tCarbon fiber tubeless-ready wheel\nSkewer rear\t12mm thru-axle\nTire\tTubeless-ready tire\nTire part\tTubeless sealant\n\nDrivetrain\nShifter\tWireless electronic shifter\nRear derailleur\tWireless electronic derailleur\nCrank\tCarbon fiber crankset with chainring\nCrank arm\tCarbon fiber crank arm\nChainring\tAlloy chainring\nCassette\t12-speed cassette\nChain\t12-speed chain\n\nComponents\nSaddle\tCarbon fiber saddle\nSeatpost\tCarbon fiber seatpost\nHandlebar\tCarbon fiber handlebar\nGrips\tComfortable and secure grips\nStem\tCarbon fiber stem\nHeadset\tCarbon fiber headset\nBrake\tHydraulic disc brakes\nBrake rotor\tDisc brake rotor\n\nAccessories\nE-bike system\tPowerful motor with removable battery\nBattery\tLithium-ion battery\nCharger\tFast charging adapter\nController\tHandlebar-mounted controller\nTool\tBasic toolkit\n\nWeight\nWeight\tM - 17.5 kg / 38.5 lbs (with tubeless sealant)\n\nWeight limit\nThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing & fit\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| S | 160 - 170 cm 5'3\" - 5'7\" | 74 - 79 cm 29\" - 31\" |\n| M | 170 - 180 cm 5'7\" - 5'11\" | 79 - 84 cm 31\" - 33\" |\n| L | 180 - 190 cm 5'11\" - 6'3\" | 84 - 89 cm 33\" - 35\" |\n| XL | 190 - 200 cm 6'3\" - 6'7\" | 89 - 94 cm 35\" - 37\" |\n\n## Geometry\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Actual frame size | 50.0 | 53.3 | 55.6 | 58.8 |\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\n| A — Seat tube | 39.4 | 43.2 | 48.3 | 53.3 |\n| B — Seat tube angle | 72.3° | 72.6° | 72.8° | 72.8° |\n| C — Head tube length | 9.0 | 10.0 | 10.5 | 11.0 |\n| D — Head angle | 67.5° | 67.5° | 67.5° | 67.5° |\n| E — Effective top tube | 58.0 | 61.7 | 64.8 | 67.0 |\n| F — Bottom bracket height | 32.3 | 32.3 | 32.3 | 32.3 |\n| G — Bottom bracket drop | 5.0 | 5.0 | 5.0 | 5.0 |\n| H — Chainstay length | 44.7 | 44.7 | 44.7 | 44.7 |\n| I — Offset | 4.2 | 4.2 | 4.2 | 4.2 |\n| J — Trail | 10.9 | 10.9 | 10.9 | 10.9 |\n| K — Wheelbase | 112.6 | 116.5 | 119.7 | 121.9 |\n| L — Standover | 76.8 | 76.8 | 76.8 | 76.8 |\n| M — Frame reach | 40.5 | 44.0 | 47.0 | 49.0 |\n| N — Frame stack | 60.9 | 61.8 | 62.2 | 62.7 |", - "price": 1699.99, - "tags": [ - "bicycle", - "electric bike", - "city bike" - ] - }, - { - "name": "Axiom D8 E-Mountain Bike", - "shortDescription": "The Axiom D8 is an electrifying mountain bike that is built for adventure. It boasts a light aluminum frame, a powerful motor and the latest tech to tackle the toughest of terrains. The D8 provides assistance without adding bulk to the bike, giving you the flexibility to ride like a traditional mountain bike or have an extra push when you need it.", - "description": "## Overview \nIt's right for you if... \nYou're looking for an electric mountain bike that can handle a wide variety of terrain, from flowing singletrack to technical descents. You also want a bike that offers a powerful motor that provides assistance without adding bulk to the bike. The D8 is designed to take you anywhere, quickly and comfortably.\n\nThe tech you get \nA lightweight aluminum frame with 140mm of travel, a Suntour fork with hydraulic lockout, and a reliable and powerful Bafang M400 mid-motor that provides a boost up to 20 mph. The bike features a Shimano Deore drivetrain, hydraulic disc brakes, and a dropper seat post. With the latest tech on-board, the D8 is designed to take you to new heights.\n\nThe final word \nThe Axiom D8 is an outstanding electric mountain bike that is designed for adventure. It's built with the latest tech and provides the flexibility to ride like a traditional mountain bike or have an extra push when you need it. Whether you're a beginner or an experienced rider, the D8 is the perfect companion for your next adventure.\n\n## Features \nBuilt for Adventure \n\nThe D8 features a lightweight aluminum frame that is built to withstand rugged terrain. It comes equipped with 140mm of travel and a Suntour fork that can handle even the toughest of trails. With this bike, you're ready to take on anything the mountain can throw at you.\n\nPowerful Motor \n\nThe Bafang M400 mid-motor provides reliable and powerful assistance without adding bulk to the bike. You can quickly and easily switch between the different assistance levels to find the perfect balance between range and power.\n\nShimano Deore Drivetrain \n\nThe Shimano Deore drivetrain is reliable and offers smooth shifting on any terrain. You can easily adjust the gears to match your riding style and maximize your performance on the mountain.\n\nDropper Seat Post \n\nThe dropper seat post allows you to easily adjust your seat height on the fly, so you can maintain the perfect position for any terrain. With the flick of a switch, you can quickly and easily lower or raise your seat to match the terrain.\n\nHydraulic Disc Brakes \n\nThe D8 features powerful hydraulic disc brakes that offer reliable stopping power in any weather condition. You can ride with confidence knowing that you have the brakes to stop on a dime.\n\n## Specs \nFrameset \nFrame\tAluminum frame with 140mm of travel \nFork\tSuntour fork with hydraulic lockout, 140mm of travel \nShock\tN/A \nMax compatible fork travel\t140mm \n \nWheels \nWheel front\tAlloy wheel \nWheel rear\tAlloy wheel \nSkewer rear\tThru axle \nTire\t29\" x 2.35\" \nTire part\tN/A \nMax tire size\t29\" x 2.6\" \n \nDrivetrain \nShifter\tShimano Deore \nRear derailleur\tShimano Deore \nCrank\tBafang M400 \nCrank arm\tN/A \nChainring\tN/A \nCassette\tShimano Deore \nChain\tShimano Deore \nMax chainring size\tN/A \n \nComponents \nSaddle\tAxiom D8 saddle \nSeatpost\tDropper seat post \nHandlebar\tAxiom D8 handlebar \nGrips\tAxiom D8 grips \nStem\tAxiom D8 stem \nHeadset\tAxiom D8 headset \nBrake\tHydraulic disc brakes \nBrake rotor\t180mm \n\nAccessories \nE-bike system\tBafang M400 mid-motor \nBattery\tLithium-ion battery, 500Wh \nCharger\tLithium-ion charger \nController\tBafang M400 controller \nTool\tN/A \n \nWeight \nWeight\tM - 22 kg / 48.5 lbs \nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 136 kg (300 lbs). \n \n \n## Sizing & fit \n \n| Size | Rider Height | Inseam | \n|:----:|:------------------------:|:--------------------:| \n| S | 152 - 165 cm 5'0\" - 5'5\" | 70 - 76 cm 27\" - 30\" | \n| M | 165 - 178 cm 5'5\" - 5'10\" | 76 - 81 cm 30\" - 32\" | \n| L | 178 - 185 cm 5'10\" - 6'1\" | 81 - 86 cm 32\" - 34\" | \n| XL | 185 - 193 cm 6'1\" - 6'4\" | 86 - 91 cm 34\" - 36\" | \n \n \n## Geometry \n \nAll measurements provided in cm unless otherwise noted. \nSizing table \n| Frame size letter | S | M | L | XL | \n|---------------------------|-------|-------|-------|-------| \n| Actual frame size | 41.9 | 46.5 | 50.8 | 55.9 | \n| Wheel size | 29\" | 29\" | 29\" | 29\" | \n| A — Seat tube | 42.0 | 46.5 | 51.0 | 56.0 | \n| B — Seat tube angle | 74.0° | 74.0° | 74.0° | 74.0° | \n| C — Head tube length | 11.0 | 12.0 | 13.0 | 15.0 | \n| D — Head angle | 68.0° | 68.0° | 68.0° | 68.0° | \n| E — Effective top tube | 57.0 | 60.0 | 62.0 | 65.0 | \n| F — Bottom bracket height | 33.0 | 33.0 | 33.0 | 33.0 | \n| G — Bottom bracket drop | 3.0 | 3.0 | 3.0 | 3.0 | \n| H — Chainstay length | 46.0 | 46.0 | 46.0 | 46.0 | \n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 | \n| J — Trail | 10.9 | 10.9 | 10.9 | 10.9 | \n| K — Wheelbase | 113.0 | 116.0 | 117.5 | 120.5 | \n| L — Standover | 73.5 | 75.5 | 76.5 | 79.5 | \n| M — Frame reach | 41.0 | 43.5 | 45.0 | 47.5 | \n| N — Frame stack | 60.5 | 61.5 | 62.5 | 64.5 |", - "price": 1399.99, - "tags": [ - "bicycle", - "electric bike", - "mountain bike" - ] - }, - { - "name": "Velocity X1", - "shortDescription": "Velocity X1 is a high-performance road bike designed for speed enthusiasts. It features a lightweight yet durable frame, aerodynamic design, and top-quality components, making it the perfect choice for those who want to take their cycling experience to the next level.", - "description": "## Overview\nIt's right for you if...\nYou're an experienced cyclist looking for a bike that can keep up with your need for speed. You want a bike that's lightweight, aerodynamic, and built to perform, whether you're training for a race or just pushing yourself to go faster.\n\nThe tech you get\nA lightweight aluminum frame with a carbon fork, Shimano Ultegra groupset with a wide range of gearing, hydraulic disc brakes, aerodynamic carbon wheels, and a vibration-absorbing handlebar with ergonomic grips.\n\nThe final word\nVelocity X1 is the ultimate road bike for speed enthusiasts. Its lightweight frame, aerodynamic design, and top-quality components make it the perfect choice for those who want to take their cycling experience to the next level.\n\n\n## Features\n\nAerodynamic design\nVelocity X1 is built with an aerodynamic design to help you go faster with less effort. It features a sleek profile, hidden cables, and a carbon fork that cuts through the wind, reducing drag and increasing speed.\n\nHydraulic disc brakes\nVelocity X1 comes equipped with hydraulic disc brakes, providing excellent stopping power in all weather conditions. They're also low maintenance, with minimal adjustments needed over time.\n\nCarbon wheels\nThe Velocity X1's aerodynamic carbon wheels provide excellent speed and responsiveness, helping you achieve your fastest times yet. They're also lightweight, reducing overall bike weight and making acceleration and handling even easier.\n\nShimano Ultegra groupset\nThe Shimano Ultegra groupset provides smooth shifting and reliable performance, ensuring you get the most out of every ride. With a wide range of gearing options, it's ideal for tackling any terrain, from steep climbs to fast descents.\n\n\n## Specifications\nFrameset\nFrame with Fork\tAluminium frame, internal cable routing, 135x9mm QR\nFork\tCarbon, hidden cable routing, 100x9mm QR\n\nWheels\nWheel front\tCarbon, 30mm deep rim, 23mm width, 100x9mm QR\nWheel rear\tCarbon, 30mm deep rim, 23mm width, 135x9mm QR\nSkewer front\t100x9mm QR\nSkewer rear\t135x9mm QR\nTire\tContinental Grand Prix 5000, 700x25mm, folding bead\nMax tire size\t700x28mm without fenders\n\nDrivetrain\nShifter\tShimano Ultegra R8020, 11 speed\nRear derailleur\tShimano Ultegra R8000, 11 speed\n*Crank\tSize: S, M\nShimano Ultegra R8000, 50/34T, 170mm length\nSize: L, XL\nShimano Ultegra R8000, 50/34T, 175mm length\nBottom bracket\tShimano BB-RS500-PB, PressFit\nCassette\tShimano Ultegra R8000, 11-30T, 11 speed\nChain\tShimano Ultegra HG701, 11 speed\nPedal\tNot included\nMax chainring size\t50/34T\n\nComponents\nSaddle\tBontrager Montrose Comp, steel rails, 138mm width\nSeatpost\tBontrager Comp, 6061 alloy, 27.2mm, 8mm offset, 330mm length\n*Handlebar\tSize: S, M, L\nBontrager Elite Aero VR-CF, alloy, 31.8mm, 93mm reach, 123mm drop, 400mm width\nSize: XL\nBontrager Elite Aero VR-CF, alloy, 31.8mm, 93mm reach, 123mm drop, 420mm width\nGrips\tBontrager Supertack Perf tape\n*Stem\tSize: S, M, L\nBontrager Elite Blendr, 31.8mm clamp, 7 degree, 90mm length\nSize: XL\nBontrager Elite Blendr, 31.8mm clamp, 7 degree, 100mm length\nBrake\tShimano Ultegra R8070 hydraulic disc, flat mount\nBrake rotor\tShimano RT800, centerlock, 160mm\nRotor size\tMax brake rotor sizes: 160mm front & rear\n\nWeight\nWeight\tM - 8.15 kg / 17.97 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| S | 162 - 170 cm 5'4\" - 5'7\" | 74 - 78 cm 29\" - 31\" |\n| M | 170 - 178 cm 5'7\" - 5'10\" | 77 - 82 cm 30\" - 32\" |\n| L | 178 - 186 cm 5'10\" - 6'1\" | 82 - 86 cm 32\" - 34\" |\n| XL | 186 - 196 cm 6'1\" - 6'5\" | 87 - 92 cm 34\" - 36\" |\n\n\n## Geometry\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 50.0 | 52.0 | 54.0 | 56.0 |\n| B — Seat tube angle | 74.0° | 73.5° | 73.0° | 72.5° |\n| C — Head tube length | 13.0 | 15.0 | 17.0 | 19.0 |\n| D — Head angle | 71.0° | 72.0° | 72.0° | 72.5° |\n| E — Effective top tube | 53.7 | 55.0 | 56.5 | 58.0 |\n| F — Bottom bracket height | 27.5 | 27.5 | 27.5 | 27.5 |\n| G — Bottom bracket drop | 7.3 | 7.3 | 7.3 | 7.3 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 6.0 | 6.0 | 6.0 | 5.8 |\n| K — Wheelbase | 98.2 | 99.1 | 100.1 | 101.0 |\n| L — Standover | 75.2 | 78.2 | 81.1 | 84.1 |\n| M — Frame reach | 37.5 | 38.3 | 39.1 | 39.9 |\n| N — Frame stack | 53.3 | 55.4 | 57.4 | 59.5 |", - "price": 1799.99, - "tags": [ - "bicycle", - "touring bike" - ] - }, - { - "name": "Velocity V9", - "shortDescription": "Velocity V9 is a high-performance hybrid bike that combines speed and comfort for riders who demand the best of both worlds. The lightweight aluminum frame, along with the carbon fork and seat post, provide optimal stiffness and absorption to tackle any terrain. A 2x Shimano Deore drivetrain, hydraulic disc brakes, and 700c wheels with high-quality tires make it a versatile ride for commuters, fitness riders, and weekend adventurers alike.", - "description": "## Overview\nIt's right for you if...\nYou want a fast, versatile bike that can handle anything from commuting to weekend adventures. You value comfort as much as speed and performance. You want a reliable and durable bike that will last for years to come.\n\nThe tech you get\nA lightweight aluminum frame with a carbon fork and seat post, a 2x Shimano Deore drivetrain with a wide range of gearing, hydraulic disc brakes, and 700c wheels with high-quality tires. The Velocity V9 is designed for riders who demand both performance and comfort in one package.\n\nThe final word\nThe Velocity V9 is the perfect bike for riders who want speed and performance without sacrificing comfort. The lightweight aluminum frame and carbon components provide optimal stiffness and absorption, while the 2x Shimano Deore drivetrain and hydraulic disc brakes ensure precise shifting and stopping power. Whether you're commuting, hitting the trails, or training for your next race, the Velocity V9 has everything you need to achieve your goals.\n\n## Features\n\n2x drivetrain\nA 2x drivetrain means more versatility and a wider range of gearing options. Whether you're climbing hills or sprinting on the flats, the Velocity V9 has the perfect gear for any situation.\n\nCarbon components\nThe Velocity V9 features a carbon fork and seat post to provide optimal stiffness and absorption. This means you can ride faster and more comfortably over any terrain.\n\nHydraulic disc brakes\nHydraulic disc brakes provide unparalleled stopping power and modulation in any weather condition. You'll feel confident and in control no matter where you ride.\n\n## Specifications\nFrameset\nFrame with Fork\tAluminum frame with carbon fork and seat post, internal cable routing, fender mounts, 135x5mm ThruSkew\nFork\tCarbon fork, hidden fender mounts, flat mount disc, 5x100mm thru-skew\n\nWheels\nWheel front\tDouble wall aluminum rims, 700c, quick release hub\nWheel rear\tDouble wall aluminum rims, 700c, quick release hub\nTire\tKenda Kwick Tendril, puncture resistant, reflective sidewall, 700x32c\nMax tire size\t700x35c without fenders, 700x32c with fenders\n\nDrivetrain\nShifter\tShimano Deore, 10 speed\nFront derailleur\tShimano Deore\nRear derailleur\tShimano Deore\nCrank\tShimano Deore, 46-30T, 170mm (S/M), 175mm (L/XL)\nBottom bracket\tShimano BB52, 68mm, threaded\nCassette\tShimano Deore, 11-36T, 10 speed\nChain\tShimano HG54, 10 speed\nPedal\tWellgo alloy platform\n\nComponents\nSaddle\tVelo VL-2158, steel rails\nSeatpost\tCarbon seat post, 27.2mm\nHandlebar\tAluminum, 31.8mm clamp, 15mm rise, 680mm width\nGrips\tVelo ergonomic grips\nStem\tAluminum, 31.8mm clamp, 7 degree, 90mm length\nBrake\tShimano hydraulic disc, MT200 lever, MT200 caliper\nBrake rotor\tShimano RT56, centerlock, 160mm\nRotor size\tMax brake rotor sizes: 160mm front & rear\n\nWeight\nWeight\tM - 11.5 kg / 25.35 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size | S | M | L | XL |\n|--------------------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 44.0 | 48.0 | 52.0 | 56.0 |\n| B — Seat tube angle | 74.5° | 74.0° | 73.5° | 73.0° |\n| C — Head tube length | 14.5 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 71.0° | 71.0° | 71.5° | 71.5° |\n| E — Effective top tube | 56.5 | 57.5 | 58.5 | 59.5 |\n| F — Bottom bracket height | 27.0 | 27.0 | 27.0 | 27.0 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 43.0 | 43.0 | 43.0 | 43.0 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 7.0 | 7.0 | 6.6 | 6.6 |\n| K — Wheelbase | 105.4 | 106.3 | 107.2 | 108.2 |\n| L — Standover | 73.2 | 77.1 | 81.2 | 85.1 |\n| M — Frame reach | 39.0 | 39.8 | 40.4 | 41.3 |\n| N — Frame stack | 57.0 | 58.5 | 60.0 | 61.5 |", - "price": 2199.99, - "tags": [ - "bicycle", - "electric bike", - "mountain bike" - ] - }, - { - "name": "Aero Pro X", - "shortDescription": "Aero Pro X is a high-end racing bike designed for serious cyclists who demand speed, agility, and superior performance. The lightweight carbon frame and fork, combined with the aerodynamic design, provide optimal stiffness and efficiency to maximize your speed. The bike features a 2x Shimano Ultegra drivetrain, hydraulic disc brakes, and 700c wheels with high-quality tires. Whether you're competing in a triathlon or climbing steep hills, Aero Pro X delivers exceptional performance and precision handling.", - "description": "## Overview\nIt's right for you if...\nYou are a competitive cyclist looking for a bike that is designed for racing. You want a bike that delivers exceptional speed, agility, and precision handling. You demand superior performance and reliability from your equipment.\n\nThe tech you get\nA lightweight carbon frame with an aerodynamic design, a carbon fork with hidden fender mounts, a 2x Shimano Ultegra drivetrain with a wide range of gearing, hydraulic disc brakes, and 700c wheels with high-quality tires. Aero Pro X is designed for serious cyclists who demand nothing but the best.\n\nThe final word\nAero Pro X is the ultimate racing bike for serious cyclists. The lightweight carbon frame and aerodynamic design deliver maximum speed and efficiency, while the 2x Shimano Ultegra drivetrain and hydraulic disc brakes ensure precise shifting and stopping power. Whether you're competing in a triathlon or a criterium race, Aero Pro X delivers the performance you need to win.\n\n## Features\n\nAerodynamic design\nThe Aero Pro X features an aerodynamic design that reduces drag and maximizes efficiency. The bike is optimized for speed and agility, so you can ride faster and farther with less effort.\n\nHydraulic disc brakes\nHydraulic disc brakes provide unrivaled stopping power and modulation in any weather condition. You'll feel confident and in control no matter where you ride.\n\nCarbon components\nThe Aero Pro X features a carbon fork with hidden fender mounts to provide optimal stiffness and absorption. This means you can ride faster and more comfortably over any terrain.\n\n## Specifications\nFrameset\nFrame with Fork\tCarbon frame with an aerodynamic design, internal cable routing, 3s chain keeper, 142x12mm thru-axle\nFork\tCarbon fork with hidden fender mounts, flat mount disc, 100x12mm thru-axle\n\nWheels\nWheel front\tDouble wall carbon rims, 700c, thru-axle hub\nWheel rear\tDouble wall carbon rims, 700c, thru-axle hub\nTire\tContinental Grand Prix 5000, folding bead, 700x25c\nMax tire size\t700x28c without fenders, 700x25c with fenders\n\nDrivetrain\nShifter\tShimano Ultegra, 11 speed\nFront derailleur\tShimano Ultegra\nRear derailleur\tShimano Ultegra\nCrank\tShimano Ultegra, 52-36T, 170mm (S), 172.5mm (M), 175mm (L/XL)\nBottom bracket\tShimano BB72, 68mm, PressFit\nCassette\tShimano Ultegra, 11-30T, 11 speed\nChain\tShimano HG701, 11 speed\nPedal\tNot included\n\nComponents\nSaddle\tBontrager Montrose Elite, carbon rails, 138mm width\nSeatpost\tCarbon seat post, 27.2mm, 20mm offset\nHandlebar\tBontrager XXX Aero, carbon, 31.8mm clamp, 75mm reach, 125mm drop\nGrips\tBontrager Supertack Perf tape\nStem\tBontrager Pro, 31.8mm clamp, 7 degree, 90mm length\nBrake\tShimano hydraulic disc, Ultegra lever, Ultegra caliper\nBrake rotor\tShimano RT800, centerlock, 160mm\nRotor size\tMax brake rotor sizes: 160mm front & rear\n\nWeight\nWeight\tM - 8.36 kg / 18.42 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\n| Size | Rider Height |\n|:----:|:-------------------------:|\n| S | 155 - 165 cm 5'1\" - 5'5\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" |\n\n## Geometry\n| Frame size | S | M | L | XL |\n|--------------------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 50.6 | 52.4 | 54.3 | 56.2 |\n| B — Seat tube angle | 75.5° | 74.5° | 73.5° | 72.5° |\n| C — Head tube length | 12.0 | 14.0 | 16.0 | 18.0 |\n| D — Head angle | 72.5° | 73.0° | 73.5° | 74.0° |\n| E — Effective top tube | 53.8 | 55.4 | 57.0 | 58.6 |\n| F — Bottom bracket height | 26.5 | 26.5 | 26.5 | 26.5 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 6.0 | 6.0 | 6.0 | 6.0 |\n| K — Wheelbase | 97.1 | 98.7 | 100.2 | 101.8 |\n| L — Standover | 73.8 | 76.2 | 78.5 | 80.8 |\n| M — Frame reach | 38.8 | 39.5 | 40.2 | 40.9 |\n| N — Frame stack | 52.8 | 54.7 | 56.6 | 58.5 |", - "price": 1599.99, - "tags": [ - "bicycle", - "road bike" - ] - }, - { - "name": "Voltex+ Ultra Lowstep", - "shortDescription": "Voltex+ Ultra Lowstep is a high-performance electric hybrid bike designed for riders who seek speed, comfort, and reliability during their everyday rides. Equipped with a powerful and efficient Voltex Drive Pro motor and a fully-integrated 600Wh battery, this e-bike allows you to cover longer distances on a single charge. The Voltex+ Ultra Lowstep comes with premium components that prioritize comfort and safety, such as a suspension seatpost, wide and stable tires, and integrated lights.", - "description": "## Overview\n\nIt's right for you if...\nYou want an e-bike that provides a boost for faster rides and effortless usage. Durability is crucial, and you need a bike with one of the most powerful and efficient motors.\n\nThe tech you get\nA lightweight Delta Carbon Fiber frame with an ultra-lowstep design, a Voltex Drive Pro (350W, 75Nm) motor capable of maintaining speeds up to 30 mph, an extended range 600Wh battery integrated into the frame, and a Voltex Control Panel. Additionally, it features a 12-speed Shimano drivetrain, hydraulic disc brakes for optimal all-weather stopping power, a suspension seatpost, wide puncture-resistant tires for added stability, ergonomic grips, a kickstand, lights, and a cargo rack.\n\nThe final word\nThis bike offers enhanced enjoyment and ease of use on long commutes, leisure rides, and adventures. With its extended-range battery, powerful Voltex motor, user-friendly controller, and a seatpost that smooths out road vibrations, it guarantees an exceptional riding experience.\n\n## Features\n\nUltra-fast assistance\n\nExperience speeds up to 30 mph with the cutting-edge Voltex Drive Pro motor, allowing you to breeze through errands, commutes, and joyrides.\n\n## Specs\n\nFrameset\n- Frame: Delta Carbon Fiber, Removable Integrated Battery (RIB), sleek welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\n- Fork: Voltex Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Hub front: Formula DC-20, alloy, 6-bolt, 5x100mm QR\n- Skewer front: 132x5mm QR, ThruSkew\n- Hub rear: Formula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Skewer rear: 153x5mm bolt-on\n- Rim: Voltex Connection, double-wall, 32-hole, 20 mm width, Schrader valve\n- Tire: Voltex E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore XT M8100, 12-speed\n- Rear derailleur: Shimano Deore XT M8100, long cage\n- Crank: Voltex alloy, 170mm length\n- Chainring: FSA, 44T, aluminum with guard\n- Cassette: Shimano Deore XT M8100, 10-51, 12-speed\n- Chain: KMC E12 Turbo\n- Pedal: Voltex Urban pedals\n\nComponents\n- Saddle: Voltex Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar: Voltex alloy, 31.8mm, comfort sweep, 620mm width (XS, S, M), 660mm width (L)\n- Grips: Voltex Satellite Elite, alloy lock-on\n- Stem: Voltex alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length (XS, S), 105mm length (M, L)\n- Headset: VP sealed cartridge, 1-1/8'', threaded\n- Brake: Shimano MT520 hydraulic disc\n- Brake rotor: Shimano RT56, 6-bolt, 180mm (XS, S, M, L), 160mm (XS, S, M, L)\n\nAccessories\n- Battery: Voltex PowerTube 600Wh\n- Charger: Voltex compact 2A, 100-240V\n- Computer: Voltex Control Panel\n- Motor: Voltex Drive Pro, 75Nm, 30mph\n- Light: Voltex Solo for e-bike, taillight (XS, S, M, L), Voltex MR8, 180 lumen, 60 lux, LED, headlight (XS, S, M, L)\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: Voltex-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender: Voltex wide (XS, S, M, L), Voltex plastic (XS, S, M, L)\n\nWeight\n- Weight: M - 20.50 kg / 45.19 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 330 pounds (150 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\nSizing table\n\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 38.0 | 43.0 | 48.0 | 53.0 |\n| B — Seat tube angle | 70.5° | 70.5° | 70.5° | 70.5° |\n| C — Head tube length | 15.0 | 15.0 | 17.0 | 19.0 |\n| D — Head angle | 69.2° | 69.2° | 69.2° | 69.2° |\n| E — Effective top tube | 57.2 | 57.7 | 58.8 | 60.0 |\n| F — Bottom bracket height | 30.3 | 30.3 | 30.3 | 30.3 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.5 | 48.5 | 48.5 | 48.5 |\n| I — Offset | 5.0 | 5.0 | 5.0 | 5.0 |\n| J — Trail | 9.0 | 9.0 | 9.0 | 9.0 |\n| K — Wheelbase | 111.8 | 112.3 | 113.6 | 114.8 |\n| L — Standover | 42.3 | 42.3 | 42.3 | 42.3 |\n| M — Frame reach | 36.0 | 38.0 | 38.0 | 38.0 |\n| N — Frame stack | 62.0 | 62.0 | 63.9 | 65.8 |\n| Stem length | 8.0 | 8.5 | 8.5 | 10.5 |\n\nPlease note that the specifications and features listed above are subject to change and may vary based on different models and versions of the Voltex+ Ultra Lowstep bike.", - "price": 2999.99, - "tags": [ - "bicycle", - "road bike", - "professional" - ] - }, - { - "name": "SwiftRide Hybrid", - "shortDescription": "SwiftRide Hybrid is a versatile and efficient bike designed for riders who want a smooth and enjoyable ride on various terrains. It incorporates advanced technology and high-quality components to provide a comfortable and reliable cycling experience.", - "description": "## Overview\n\nIt's right for you if...\nYou are looking for a bike that combines the benefits of an electric bike with the versatility of a hybrid. You value durability, speed, and ease of use.\n\nThe tech you get\nThe SwiftRide Hybrid features a lightweight and durable aluminum frame, making it easy to handle and maneuver. It is equipped with a powerful electric motor that offers a speedy assist, helping you reach speeds of up to 25 mph. The bike comes with a removable and fully-integrated 500Wh battery, providing a long-range capacity for extended rides. It also includes a 10-speed Shimano drivetrain, hydraulic disc brakes for precise stopping power, wide puncture-resistant tires for stability, and integrated lights for enhanced visibility.\n\nThe final word\nThe SwiftRide Hybrid is designed for riders who want a bike that can handle daily commutes, recreational rides, and adventures. With its efficient motor, intuitive controls, and comfortable features, it offers an enjoyable and hassle-free riding experience.\n\n## Features\n\nEfficient electric assist\nExperience the thrill of effortless riding with the powerful electric motor that provides a speedy assist, making your everyday rides faster and more enjoyable.\n\n## Specs\n\nFrameset\n- Frame: Lightweight Aluminum, Removable Integrated Battery (RIB), rack & fender mounts, internal routing, 135x5mm QR\n- Fork: SwiftRide Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Hub front: Formula DC-20, alloy, 6-bolt, 5x100mm QR\n- Skewer front: 132x5mm QR, ThruSkew\n- Hub rear: Formula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Skewer rear: 153x5mm bolt-on\n- Rim: SwiftRide Connection, double-wall, 32-hole, 20 mm width, Schrader valve\n- Tire: SwiftRide E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear derailleur: Shimano Deore M5120, long cage\n- Crank: ProWheel alloy, 170mm length\n- Chainring: FSA, 42T, steel w/guard\n- Cassette: Shimano Deore M4100, 11-42, 10 speed\n- Chain: KMC E10\n- Pedal: SwiftRide City pedals\n\nComponents\n- Saddle: SwiftRide Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar:\n - Size: XS, S, M - SwiftRide alloy, 31.8mm, comfort sweep, 620mm width\n - Size: L - SwiftRide alloy, 31.8mm, comfort sweep, 660mm width\n- Grips: SwiftRide Satellite Elite, alloy lock-on\n- Stem:\n - Size: XS, S - SwiftRide alloy quill, 31.8mm clamp, adjustable rise, 85mm length\n - Size: M, L - SwiftRide alloy quill, 31.8mm clamp, adjustable rise, 105mm length\n- Headset: VP sealed cartridge, 1-1/8'', threaded\n- Brake: Shimano MT200 hydraulic disc\n- Brake rotor:\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 180mm\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 160mm\n\nAccessories\n- Battery: SwiftRide PowerTube 500Wh\n- Charger: SwiftRide compact 2A, 100-240V\n- Computer: SwiftRide Purion\n- Motor: SwiftRide Performance Line Sport, 65Nm, 25mph\n- Light:\n - Size: XS, S, M, L - SwiftRide SOLO for e-bike, taillight\n - Size: XS, S, M, L - SwiftRide MR8, 180 lumen, 60 lux, LED, headlight\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: SwiftRide-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender:\n - Size: XS, S, M, L - SwiftRide wide\n - Size: XS, S, M, L - SwiftRide plastic\n\nWeight\n- Weight: M - 22.30 kg / 49.17 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm (4'10\" - 5'1\") | 69 - 73 cm (27\" - 29\") |\n| S | 155 - 165 cm (5'1\" - 5'5\") | 72 - 78 cm (28\" - 31\") |\n| M | 165 - 175 cm (5'5\" - 5'9\") | 77 - 83 cm (30\" - 33\") |\n| L | 175 - 186 cm (5'9\" - 6'1\") | 82 - 88 cm (32\" - 35\") |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\nSizing table\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", - "price": 3999.99, - "tags": [ - "bicycle", - "mountain bike", - "professional" - ] - }, - { - "name": "RoadRunner E-Speed Lowstep", - "shortDescription": "RoadRunner E-Speed Lowstep is a high-performance electric hybrid designed for riders seeking speed and excitement on their daily rides. It is equipped with a powerful and reliable ThunderBolt drive unit that offers exceptional acceleration. The bike features a fully-integrated 500Wh battery, allowing riders to cover longer distances on a single charge. With its comfortable and safe components, including a suspension seatpost, wide and stable tires, and integrated lights, the RoadRunner E-Speed Lowstep ensures a smooth and enjoyable ride.", - "description": "## Overview\n\nIt's right for you if...\nYou're looking for an e-bike that provides an extra boost to reach your destination quickly and effortlessly. You prioritize durability and want a bike with one of the fastest motors available.\n\nThe tech you get\nA lightweight and sturdy ThunderBolt aluminum frame with a lowstep geometry. The bike is equipped with a ThunderBolt Performance Sport (250W, 65Nm) drive unit capable of reaching speeds up to 28 mph. It features a long-range 500Wh battery fully integrated into the frame and a ThunderBolt controller. Additionally, the bike has a 10-speed Shimano drivetrain, hydraulic disc brakes for reliable stopping power in all weather conditions, a suspension seatpost, wide puncture-resistant tires for stability, ergonomic grips, a kickstand, lights, and a rack and fenders.\n\nThe final word\nThe RoadRunner E-Speed Lowstep is designed to provide enjoyment and ease of use on longer commutes, recreational rides, and adventurous journeys. Its long-range battery, fast ThunderBolt motor, intuitive controller, and road-smoothing suspension seatpost make it the perfect choice for riders seeking both comfort and speed.\n\n## Features\n\nSuper speedy assist\n\nThe ThunderBolt Performance Sport drive unit allows you to accelerate up to 28mph, making errands, commutes, and joyrides a breeze.\n\n## Specs\n\nFrameset\n- Frame: ThunderBolt Smooth Aluminum, Removable Integrated Battery (RIB), sleek welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\n- Fork: RoadRunner Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Hub front: ThunderBolt DC-20, alloy, 6-bolt, 5x100mm QR\n- Skewer front: 132x5mm QR, ThruSkew\n- Hub rear: ThunderBolt DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Skewer rear: 153x5mm bolt-on\n- Rim: ThunderBolt Connection, double-wall, 32-hole, 20 mm width, Schrader valve\n- Tire: ThunderBolt E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear derailleur: Shimano Deore M5120, long cage\n- Crank: ProWheel alloy, 170mm length\n- Chainring: FSA, 42T, steel w/guard\n- Cassette: Shimano Deore M4100, 11-42, 10 speed\n- Chain: KMC E10\n- Pedal: RoadRunner City pedals\n\nComponents\n- Saddle: RoadRunner Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar:\n - Size: XS, S, M - RoadRunner alloy, 31.8mm, comfort sweep, 620mm width\n - Size: L - RoadRunner alloy, 31.8mm, comfort sweep, 660mm width\n- Grips: RoadRunner Satellite Elite, alloy lock-on\n- Stem:\n - Size: XS, S - RoadRunner alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length\n - Size: M, L - RoadRunner alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 105mm length\n- Headset: VP sealed cartridge, 1-1/8'', threaded\n- Brake: Shimano MT200 hydraulic disc\n- Brake rotor:\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 180mm\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 160mm\n\nAccessories\n- Battery: ThunderBolt PowerTube 500Wh\n- Charger: ThunderBolt compact 2A, 100-240V\n- Computer: ThunderBolt Purion\n- Motor: ThunderBolt Performance Line Sport, 65Nm, 28mph\n- Light:\n - Size: XS, S, M, L - ThunderBolt SOLO for e-bike, taillight\n - Size: XS, S, M, L - ThunderBolt MR8, 180 lumen, 60 lux, LED, headlight\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: MIK-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender:\n - Size: XS, S, M, L - RoadRunner wide\n - Size: XS, S, M, L - RoadRunner plastic\n\nWeight\n- Weight: M - 22.30 kg / 49.17 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\nSizing table\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", - "price": 4999.99, - "tags": [ - "bicycle", - "road bike", - "professional" - ] - }, - { - "name": "Hyperdrive Turbo X1", - "shortDescription": "Hyperdrive Turbo X1 is a high-performance electric bike designed for riders seeking an exhilarating experience on their daily rides. It features a powerful and efficient Hyperdrive Sport drive unit and a sleek, integrated 500Wh battery for extended range. This e-bike is equipped with top-of-the-line components prioritizing comfort and safety, including a suspension seatpost, wide and stable tires, and integrated lights.", - "description": "## Overview\n\nIt's right for you if...\nYou crave the thrill of an e-bike that can accelerate rapidly, reaching high speeds effortlessly. You value durability and are looking for a bike that is equipped with one of the fastest motors available.\n\nThe tech you get\nA lightweight Hyper Alloy frame with a lowstep geometry, a Hyperdrive Sport (300W, 70Nm) drive unit capable of maintaining speeds up to 30 mph, a long-range 500Wh battery seamlessly integrated into the frame, and an intuitive Hyper Control controller. Additionally, it features a 10-speed Shimano drivetrain, hydraulic disc brakes for reliable stopping power in all weather conditions, a suspension seatpost, wide puncture-resistant tires for enhanced stability, ergonomic grips, a kickstand, lights, and a rack and fenders.\n\nThe final word\nThis bike is designed for riders seeking enjoyment and convenience on longer commutes, recreational rides, and thrilling adventures. With its long-range battery, high-speed motor, user-friendly controller, and smooth-riding suspension seatpost, the Hyperdrive Turbo X1 guarantees an exceptional e-biking experience.\n\n## Features\n\nHyperboost Acceleration\nExperience adrenaline-inducing rides with the powerful Hyperdrive Sport drive unit that enables quick acceleration and effortless cruising through errands, commutes, and joyrides.\n\n## Specs\n\nFrameset\nFrame\tHyper Alloy, Removable Integrated Battery (RIB), seamless welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\nFork\tHyper Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\nMax compatible fork travel\t50mm\n\nWheels\nHub front\tFormula DC-20, alloy, 6-bolt, 5x100mm QR\nSkewer front\t132x5mm QR, ThruSkew\nHub rear\tFormula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\nSkewer rear\t153x5mm bolt-on\nRim\tHyper Connection, double-wall, 32-hole, 20 mm width, Schrader valve\nTire\tHyper E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\nMax tire size\t700x50mm with or without fenders\n\nDrivetrain\nShifter\tShimano Deore M4100, 10 speed\nRear derailleur\tShimano Deore M5120, long cage\nCrank\tProWheel alloy, 170mm length\nChainring\tFSA, 42T, steel w/guard\nCassette\tShimano Deore M4100, 11-42, 10 speed\nChain\tKMC E10\nPedal\tHyper City pedals\n\nComponents\nSaddle\tHyper Boulevard\nSeatpost\tAlloy, suspension, 31.6mm, 300mm length\n*Handlebar\tSize: XS, S, M\nHyper alloy, 31.8mm, comfort sweep, 620mm width\nSize: L\nHyper alloy, 31.8mm, comfort sweep, 660mm width\nGrips\tHyper Satellite Elite, alloy lock-on\n*Stem\tSize: XS, S\nHyper alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length\nSize: M, L\nHyper alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 105mm length\nHeadset\tVP sealed cartridge, 1-1/8'', threaded\nBrake\tShimano MT200 hydraulic disc\n*Brake rotor\tSize: XS, S, M, L\nShimano RT26, 6-bolt,180mm\nSize: XS, S, M, L\nShimano RT26, 6-bolt,160mm\n\nAccessories\nBattery\tHyper PowerTube 500Wh\nCharger\tHyper compact 2A, 100-240V\nComputer\tHyper Control\nMotor\tHyperdrive Sport, 70Nm, 30mph\n*Light\tSize: XS, S, M, L\nSpanninga SOLO for e-bike, taillight\nSize: XS, S, M, L\nHerrmans MR8, 180 lumen, 60 lux, LED, headlight\nKickstand\tAdjustable length rear mount alloy kickstand\nCargo rack\tMIK-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n*Fender\tSize: XS, S, M, L\nSKS wide\nSize: XS, S, M, L\nSKS plastic\n\nWeight\nWeight\tM - 22.30 kg / 49.17 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", - "price": 1999.99, - "tags": [ - "bicycle", - "city bike", - "professional" - ] - }, - { - "name": "Horizon+ Evo Lowstep", - "shortDescription": "The Horizon+ Evo Lowstep is a versatile electric hybrid bike designed for riders seeking a thrilling and efficient riding experience on a variety of terrains. With its powerful Bosch Performance Line Sport drive unit and integrated 500Wh battery, this e-bike enables riders to cover long distances with ease. Equipped with features prioritizing comfort and safety, such as a suspension seatpost, stable tires, and integrated lights, the Horizon+ Evo Lowstep is a reliable companion for everyday rides.", - "description": "## Overview\n\nIt's right for you if...\nYou desire the convenience and speed of an e-bike to enhance your riding, and you want an intuitive and durable bicycle. You prioritize having one of the fastest motors developed by Bosch.\n\nThe tech you get\nA lightweight Alpha Smooth Aluminum frame with a lowstep geometry, a Bosch Performance Line Sport (250W, 65Nm) drive unit capable of sustaining speeds up to 28 mph, a fully encased 500Wh battery integrated into the frame, and a Bosch Purion controller. Additionally, it features a 10-speed Shimano drivetrain, hydraulic disc brakes for reliable stopping power in all weather conditions, a suspension seatpost, wide puncture-resistant tires for improved stability, ergonomic grips, a kickstand, lights, and a rack and fenders.\n\nThe final word\nThe Horizon+ Evo Lowstep offers an enjoyable and user-friendly riding experience for longer commutes, recreational rides, and adventures. It boasts an extended range battery, a high-performance Bosch motor, an intuitive controller, and a suspension seatpost for a smooth ride on various road surfaces.\n\n## Features\n\nSuper speedy assist\nExperience effortless cruising through errands, commutes, and joyrides with the new Bosch Performance Sport drive unit, allowing acceleration of up to 28 mph.\n\n## Specs\n\nFrameset\n- Frame: Alpha Platinum Aluminum, Removable Integrated Battery (RIB), smooth welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\n- Fork: Horizon Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Front Hub: Formula DC-20, alloy, 6-bolt, 5x100mm QR\n- Front Skewer: 132x5mm QR, ThruSkew\n- Rear Hub: Formula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Rear Skewer: 153x5mm bolt-on\n- Rim: Bontrager Connection, double-wall, 32-hole, 20mm width, Schrader valve\n- Tire: Bontrager E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10-speed\n- Rear Derailleur: Shimano Deore M5120, long cage\n- Crank: ProWheel alloy, 170mm length\n- Chainring: FSA, 42T, steel w/guard\n- Cassette: Shimano Deore M4100, 11-42, 10-speed\n- Chain: KMC E10\n- Pedal: Bontrager City pedals\n\nComponents\n- Saddle: Bontrager Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar:\n - Size: XS, S, M - Bontrager alloy, 31.8mm, comfort sweep, 620mm width\n - Size: L - Bontrager alloy, 31.8mm, comfort sweep, 660mm width\n- Grips: Bontrager Satellite Elite, alloy lock-on\n- Stem:\n - Size: XS, S - Bontrager alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length\n - Size: M, L - Bontrager alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 105mm length\n- Headset: VP sealed cartridge, 1-1/8\", threaded\n- Brake: Shimano MT200 hydraulic disc\n- Brake rotor:\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 180mm\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 160mm\n\nAccessories\n- Battery: Bosch PowerTube 500Wh\n- Charger: Bosch compact 2A, 100-240V\n- Computer: Bosch Purion\n- Motor: Bosch Performance Line Sport, 65Nm, 28mph\n- Light:\n - Size: XS, S, M, L - Spanninga SOLO for e-bike, taillight\n - Size: XS, S, M, L - Herrmans MR8, 180 lumen, 60 lux, LED, headlight\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: MIK-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender:\n - Size: XS, S, M, L - SKS wide\n - Size: XS, S, M, L - SKS plastic\n\nWeight\n- Weight: M - 22.30 kg / 49.17 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", - "price": 4499.99, - "tags": [ - "bicycle", - "road bike", - "professional" - ] - }, - { - "name": "FastRider X1", - "shortDescription": "FastRider X1 is a high-performance e-bike designed for riders seeking speed and long-distance capabilities. Equipped with a powerful motor and a high-capacity battery, the FastRider X1 is perfect for daily commuters and e-bike enthusiasts. It boasts a sleek and functional design, making it a great alternative to car transportation. The bike also features a smartphone controller for easy navigation and entertainment options.", - "description": "## Overview\nIt's right for you if...\nYou're looking for an e-bike that offers both speed and endurance. The FastRider X1 comes with a high-performance motor and a long-lasting battery, making it ideal for long-distance rides.\n\nThe tech you get\nThe FastRider X1 features a state-of-the-art motor and a spacious battery, ensuring a fast and efficient ride.\n\nThe final word\nWith the powerful motor and long-range battery, the FastRider X1 allows you to cover more distance at higher speeds.\n\n## Features\nConnect Your Ride with the FastRider App\nDownload the FastRider app and transform your smartphone into an on-board computer. Easily dock and charge your phone with the smartphone controller, and use the thumb pad on your handlebar to make calls, listen to music, get turn-by-turn directions, and more. The app also allows you to connect with fitness and health apps, syncing your routes and ride data.\n\nGoodbye, Car. Hello, Extended Range!\nWith the option to add the Range Boost feature, you can attach a second long-range battery to your FastRider X1, doubling the distance and time between charges. This enhancement allows you to ride longer, commute farther, and take on more adventurous routes.\n\nWhat is the range?\nTo estimate the distance you can travel on a single charge, use our range calculator tool. It automatically fills in the variables for this specific bike model and assumes an average rider, but you can adjust the settings to get the most accurate estimate for your needs.\n\n## Specifications\nFrameset\n- Frame: High-performance hydroformed alloy, Removable Integrated Battery, Range Boost-compatible, internal cable routing, Motor Armour, post-mount disc, 135x5 mm QR\n- Fork: FastRider rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru axle, post mount disc brake\n- Max compatible fork travel: 63mm\n\nWheels\n- Front Hub: FastRider sealed bearing, 32-hole 15mm alloy thru-axle\n- Front Skewer: FastRider Switch thru axle, removable lever\n- Rear Hub: FastRider alloy, sealed bearing, 6-bolt, 135x5mm QR\n- Rear Skewer: 148x5mm bolt-on\n- Rim: FastRider MD35, tubeless compatible, 32-hole, 35mm width, Presta valve\n- Spokes: Size: M, L, XL - 14g stainless steel, black\n- Tire: FastRider E6 Hard-Case Lite, reflective strip, 27.5x2.40''\n- Max tire size: 27.5x2.40\"\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear derailleur: Size: M, L, XL - Shimano Deore M5120, long cage\n- Crank: Size: M - FastRider alloy, 170mm length / Size: L, XL - FastRider alloy, 175mm length\n- Chainring: FastRider 46T narrow/wide alloy, w/alloy guard\n- Cassette: Size: M, L, XL - Shimano Deore M4100, 11-42, 10 speed\n- Chain: Size: M, L, XL - KMC E10 / Size: M, L, XL - KMC X10e\n- Pedal: Size: M, L, XL - FastRider City pedals / Size: M, L, XL - Wellgo C157, boron axle, plastic body / Size: M, L, XL - slip-proof aluminum pedals with reflectors\n- Max chainring size: 1x: 48T\n\nComponents\n- Saddle: FastRider Commuter Comp\n- Seatpost: FastRider Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\n- Handlebar: Size: M - FastRider alloy, 31.8mm, 15mm rise, 600mm width / Size: L, XL - FastRider alloy, 31.8mm, 15mm rise, 660mm width\n- Grips: FastRider Satellite Elite, alloy lock-on\n- Stem: Size: M - FastRider alloy, 31.8mm, Blendr compatible, 7-degree, 70mm length / Size: L - FastRider alloy, 31.8mm, Blendr compatible, 7-degree, 90mm length / Size: XL - FastRider alloy, 31.8mm, Blendr compatible, 7-degree, 100mm length\n- Headset: Size: M, L, XL - FSA IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom / Size: M, L, XL - FSA Integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\n- Brake: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\n- Brake rotor: Shimano RT56, 6-bolt, 180mm\n- Rotor size: Max brake rotor sizes: 180mm front & rear\n\nAccessories\n- Battery: FastRider PowerTube 625Wh\n- Charger: FastRider standard 4A, 100-240V\n- Motor: FastRider Performance Speed, 85 Nm, 28 mph / 45 kph\n- Light: Size: M, L, XL - FastRider taillight, 50 lumens / Size: M, L, XL - FastRider headlight, 500 lumens\n- Kickstand: Size: M, L, XL - Rear mount, alloy / Size: M, L, XL - Adjustable length alloy kickstand\n- Cargo rack: FastRider integrated rear rack, aluminum\n- Fender: FastRider custom aluminum\n\nWeight\n- Weight: M - 25.54 kg / 56.3 lbs\n\nWeight limit\n- This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |\n| N — Frame stack | 62.3 | 65.2 | 68.8 |", - "price": 5499.99, - "tags": [ - "bicycle", - "mountain bike", - "professional" - ] - }, - { - "name": "SonicRide 8S", - "shortDescription": "SonicRide 8S is a high-performance e-bike designed for riders who crave speed and long-distance capabilities. The advanced SonicDrive motor provides powerful assistance up to 28 mph, combined with a durable and long-lasting battery for extended rides. With its sleek design and thoughtful features, the SonicRide 8S is perfect for those who prefer the freedom of riding a bike over driving a car. Plus, it comes equipped with a smartphone controller for easy navigation, music, and more.", - "description": "## Overview\nIt's right for you if...\nYou want a fast and efficient e-bike that can take you long distances. The SonicRide 8S features a hydroformed aluminum frame with a concealed 625Wh battery, a high-powered SonicDrive motor, and a Smartphone Controller. It also includes essential accessories such as lights, fenders, and a rear rack.\n\nThe tech you get\nThe SonicRide 8S is equipped with the fastest SonicDrive motor, ensuring exhilarating rides at high speeds. The long-range battery is perfect for commuters and riders looking to explore new horizons.\n\nThe final word\nWith the SonicDrive motor and long-lasting battery, you can enjoy extended rides at higher speeds.\n\n## Features\n\nConnect Your Ride with SonicRide App\nDownload the SonicRide app and transform your phone into an onboard computer. Simply attach it to the Smartphone Controller for docking and charging. Use the thumb pad on your handlebar to control calls, music, directions, and more. The Bluetooth® wireless technology allows you to connect with fitness and health apps, syncing your routes and ride data.\n\nSay Goodbye to Limited Range with Range Boost!\nExperience the convenience of Range Boost, an additional long-range 500Wh battery that seamlessly attaches to your bike's down tube. This upgrade allows you to double your distance and time between charges, enabling longer commutes and more adventurous rides. Range Boost is compatible with select SonicRide electric bike models.\n\nWhat is the range?\nFor an accurate estimate of how far you can ride on a single charge, use SonicRide's range calculator. We have pre-filled the variables for this specific bike model and the average rider, but you can adjust them to obtain the most accurate estimate.\n\n## Specifications\nFrameset\n- Frame: High-performance hydroformed alloy, Removable Integrated Battery, Range Boost-compatible, internal cable routing, Motor Armour, post-mount disc, 135x5 mm QR\n- Fork: SonicRide rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru axle, post mount disc brake\n- Max compatible fork travel: 63mm\n\nWheels\n- Front Hub: SonicRide sealed bearing, 32-hole 15mm alloy thru-axle\n- Front Skewer: SonicRide Switch thru axle, removable lever\n- Rear Hub: SonicRide alloy, sealed bearing, 6-bolt, 135x5mm QR\n- Rear Skewer: 148x5mm bolt-on\n- Rim: SonicRide MD35, tubeless compatible, 32-hole, 35mm width, Presta valve\n- Spokes: Size: M, L, XL - 14g stainless steel, black\n- Tire: SonicRide E6 Hard-Case Lite, reflective strip, 27.5x2.40''\n- Max tire size: 27.5x2.40\"\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear Derailleur: Size: M, L, XL - Shimano Deore M5120, long cage\n- Crank: Size: M - SonicRide alloy, 170mm length; Size: L, XL - SonicRide alloy, 175mm length\n- Chainring: SonicRide 46T narrow/wide alloy, with alloy guard\n- Cassette: Size: M, L, XL - Shimano Deore M4100, 11-42, 10 speed\n- Chain: Size: M, L, XL - KMC E10; Size: M, L, XL - KMC X10e\n- Pedal: Size: M, L, XL - SonicRide City pedals; Size: M, L, XL - Wellgo C157, boron axle, plastic body; Size: M, L, XL - slip-proof aluminum pedals with reflectors\n- Max chainring size: 1x: 48T\n\nComponents\n- Saddle: SonicRide Commuter Comp\n- Seatpost: SonicRide Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\n- Handlebar: Size: M - SonicRide alloy, 31.8mm, 15mm rise, 600mm width; Size: L, XL - SonicRide alloy, 31.8mm, 15mm rise, 660mm width\n- Grips: SonicRide Satellite Elite, alloy lock-on\n- Stem: Size: M - SonicRide alloy, 31.8mm, Blendr compatible, 7-degree, 70mm length; Size: L - SonicRide alloy, 31.8mm, Blendr compatible, 7-degree, 90mm length; Size: XL - SonicRide alloy, 31.8mm, Blendr compatible, 7-degree, 100mm length\n- Headset: Size: M, L, XL - SonicRide IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom; Size: M, L, XL - SonicRide Integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\n- Brake: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\n- Brake rotor: Shimano RT56, 6-bolt, 180mm\n- Rotor size: Max brake rotor sizes: 180mm front & rear\n\nAccessories\n- Battery: SonicRide PowerTube 625Wh\n- Charger: SonicRide standard 4A, 100-240V\n- Motor: SonicRide Performance Speed, 85 Nm, 28 mph / 45 kph\n- Light: Size: M, L, XL - SonicRide Lync taillight, 50 lumens; Size: M, L, XL - SonicRide Lync headlight, 500 lumens\n- Kickstand: Size: M, L, XL - Rear mount, alloy; Size: M, L, XL - Adjustable length alloy kickstand\n- Cargo rack: SonicRide integrated rear rack, aluminum\n- Fender: SonicRide custom aluminum\n\nWeight\n- Weight: M - 25.54 kg / 56.3 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| M | 165 - 175 cm / 5'5\" - 5'9\" | 77 - 83 cm / 30\" - 33\" |\n| L | 175 - 186 cm / 5'9\" - 6'1\" | 82 - 88 cm / 32\" - 35\" |\n| XL | 186 - 197 cm / 6'1\" - 6'6\" | 87 - 93 cm / 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |", - "price": 5999.99, - "tags": [ - "bicycle", - "road bike", - "professional" - ] - }, - { - "name": "SwiftVolt Pro", - "shortDescription": "SwiftVolt Pro is a high-performance e-bike designed for riders seeking a thrilling and fast riding experience. Equipped with a powerful SwiftDrive motor that provides assistance up to 30 mph and a long-lasting battery, this bike is perfect for long-distance commuting and passionate e-bike enthusiasts. The sleek and innovative design features cater specifically to individuals who prioritize cycling over driving. Additionally, the bike is seamlessly integrated with your smartphone, allowing you to use it for navigation, music, and more.", - "description": "## Overview\nThis bike is ideal for you if:\n- You desire a sleek and modern hydroformed aluminum frame that houses a 700Wh battery.\n- You want to maintain high speeds of up to 30 mph with the assistance of the SwiftDrive motor.\n- You appreciate the convenience of using your smartphone as a controller, which can be docked and charged on the handlebar.\n\n## Features\n\nConnect with SwiftSync App\nBy downloading the SwiftSync app, your smartphone becomes an interactive on-board computer. Attach it to the handlebar-mounted controller for easy access and charging. With the thumb pad, you can make calls, listen to music, receive turn-by-turn directions, and connect with fitness and health apps to track your routes and ride data via Bluetooth® wireless technology.\n\nEnhanced Range with BoostMax\nBoostMax offers the capability to attach a second 700Wh Swift battery to the downtube of your bike, effectively doubling the distance and time between charges. This allows for extended rides, longer commutes, and more significant adventures. BoostMax is compatible with select Swift electric bike models.\n\nRange Estimation\nFor an estimate of how far you can ride on a single charge, consult the Swift range calculator. The variables are automatically populated based on this bike model and the average rider, but you can modify them to obtain the most accurate estimate.\n\n## Specifications\nFrameset\n- Frame: Lightweight hydroformed alloy, Removable Integrated Battery, BoostMax-compatible, internal cable routing, post-mount disc, 135x5 mm QR\n- Fork: SwiftVolt rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru-axle, post-mount disc brake\n- Max compatible fork travel: 63mm\n\nWheels\n- Front Hub: Swift sealed bearing, 32-hole 15mm alloy thru-axle\n- Front Skewer: Swift Switch thru-axle, removable lever\n- Rear Hub: Swift alloy, sealed bearing, 6-bolt, 135x5mm QR\n- Rear Skewer: 148x5mm bolt-on\n- Rim: SwiftRim, tubeless compatible, 32-hole, 35mm width, Presta valve\n- Spokes: 14g stainless steel, black\n- Tire: Swift E6 Hard-Case Lite, reflective strip, 27.5x2.40''\n- Max tire size: 27.5x2.40\"\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear Derailleur: Shimano Deore M5120, long cage\n- Crank: Swift alloy, 170mm length\n- Chainring: Swift 46T narrow/wide alloy, w/alloy guard\n- Cassette: Shimano Deore M4100, 11-42, 10 speed\n- Chain: KMC E10\n- Pedal: Swift City pedals\n- Max chainring size: 1x: 48T\n\nComponents\n- Saddle: Swift Commuter Comp\n- Seatpost: Swift Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\n- Handlebar: Swift alloy, 31.8mm, 15mm rise, 600mm width (M), 660mm width (L, XL)\n- Grips: Swift Satellite Elite, alloy lock-on\n- Stem: Swift alloy, 31.8mm, Blendr compatible, 7 degree, 70mm length (M), 90mm length (L), 100mm length (XL)\n- Headset: FSA IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\n- Brakes: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\n- Brake Rotor: Shimano RT56, 6-bolt, 180mm\n- Rotor size: Max 180mm front & rear\n\nAccessories\n- Battery: Swift PowerTube 700Wh\n- Charger: Swift standard 4A, 100-240V\n- Motor: SwiftDrive, 90 Nm, 30 mph / 48 kph\n- Light: Swift Lync taillight, 50 lumens (M, L, XL), Swift Lync headlight, 500 lumens (M, L, XL)\n- Kickstand: Rear mount, alloy (M, L, XL), Adjustable length alloy kickstand (M, L, XL)\n- Cargo rack: SwiftVolt integrated rear rack, aluminum\n- Fender: Swift custom aluminum\n\nWeight\n- Weight: M - 25.54 kg / 56.3 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:---------------------:|:-------------:|\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |\n| N — Frame stack | 62.3 | 65.2 | 68.8 |", - "price": 2499.99, - "tags": [ - "bicycle", - "city bike", - "professional" - ] - }, - { - "name": "AgileEon 9X", - "shortDescription": "AgileEon 9X is a high-performance e-bike designed for riders seeking speed and endurance. Equipped with a robust motor and an extended battery life, this bike is perfect for long-distance commuters and avid e-bike enthusiasts. It boasts innovative features tailored for individuals who prioritize cycling over driving. Additionally, the bike integrates seamlessly with your smartphone, allowing you to access navigation, music, and more.", - "description": "## Overview\nIt's right for you if...\nYou crave speed and want to cover long distances efficiently. The AgileEon 9X features a sleek hydroformed aluminum frame that houses a powerful motor, along with a large-capacity battery for extended rides. It comes equipped with a 10-speed drivetrain, front and rear lighting, fenders, and a rear rack.\n\nThe tech you get\nDesigned for those constantly on the move, this bike includes a state-of-the-art motor and a high-capacity battery, making it an excellent choice for lengthy commutes.\n\nThe final word\nWith the AgileEon 9X, you can push your boundaries and explore new horizons thanks to its powerful motor and long-lasting battery.\n\n## Features\n\nConnect Your Ride with RideMate App\nMake use of the RideMate app to transform your smartphone into an onboard computer. Simply attach it to the RideMate controller to dock and charge, then utilize the thumb pad on your handlebar to make calls, listen to music, receive turn-by-turn directions, and more. The bike also supports Bluetooth® wireless technology, enabling seamless connectivity with fitness and health apps for route syncing and ride data.\n\nGoodbye, car. Hello, Extended Range!\nEnhance your riding experience with the Extended Range option, which allows for the attachment of an additional high-capacity 500Wh battery to your bike's downtube. This doubles the distance and time between charges, enabling longer rides, extended commutes, and more significant adventures. The Extended Range feature is compatible with select AgileEon electric bike models.\n\nWhat is the range?\nTo determine how far you can ride on a single charge, you can utilize the range calculator provided by AgileEon. We have pre-filled the variables for this specific model and an average rider, but adjustments can be made for a more accurate estimation.\n\n## Specifications\nFrameset\nFrame: High-performance hydroformed alloy, Removable Integrated Battery, Extended Range-compatible, internal cable routing, Motor Armor, post-mount disc, 135x5 mm QR\nFork: AgileEon rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru-axle, post-mount disc brake\nMax compatible fork travel: 63mm\n\nWheels\nFront Hub: AgileEon sealed bearing, 32-hole 15mm alloy thru-axle\nFront Skewer: AgileEon Switch thru-axle, removable lever\nRear Hub: AgileEon alloy, sealed bearing, 6-bolt, 135x5mm QR\nRear Skewer: 148x5mm bolt-on\nRim: AgileEon MD35, tubeless compatible, 32-hole, 35mm width, Presta valve\nSpokes:\n- Size: M, L, XL: 14g stainless steel, black\nTire: AgileEon E6 Hard-Case Lite, reflective strip, 27.5x2.40''\nMax tire size: 27.5x2.40\"\n\nDrivetrain\nShifter: Shimano Deore M4100, 10-speed\nRear derailleur:\n- Size: M, L, XL: Shimano Deore M5120, long cage\nCrank:\n- Size: M: AgileEon alloy, 170mm length\n- Size: L, XL: AgileEon alloy, 175mm length\nChainring: AgileEon 46T narrow/wide alloy, with alloy guard\nCassette:\n- Size: M, L, XL: Shimano Deore M4100, 11-42, 10-speed\nChain:\n- Size: M, L, XL: KMC E10\nPedal:\n- Size: M, L, XL: AgileEon City pedals\nMax chainring size: 1x: 48T\n\nComponents\nSaddle: AgileEon Commuter Comp\nSeatpost: AgileEon Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\nHandlebar:\n- Size: M: AgileEon alloy, 31.8mm, 15mm rise, 600mm width\n- Size: L, XL: AgileEon alloy, 31.8mm, 15mm rise, 660mm width\nGrips: AgileEon Satellite Elite, alloy lock-on\nStem:\n- Size: M: AgileEon alloy, 31.8mm, Blendr compatible, 7-degree, 70mm length\n- Size: L: AgileEon alloy, 31.8mm, Blendr compatible, 7-degree, 90mm length\n- Size: XL: AgileEon alloy, 31.8mm, Blendr compatible, 7-degree, 100mm length\nHeadset:\n- Size: M, L, XL: AgileEon IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\nBrake: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\nBrake rotor: Shimano RT56, 6-bolt, 180mm\nRotor size: Max brake rotor sizes: 180mm front & rear\n\nAccessories\nBattery: AgileEon PowerTube 625Wh\nCharger: AgileEon standard 4A, 100-240V\nMotor: AgileEon Performance Speed, 85 Nm, 28 mph / 45 kph\nLight:\n- Size: M, L, XL: AgileEon taillight, 50 lumens\n- Size: M, L, XL: AgileEon headlight, 500 lumens\nKickstand:\n- Size: M, L, XL: Rear mount, alloy\n- Size: M, L, XL: Adjustable length alloy kickstand\nCargo rack: AgileEon integrated rear rack, aluminum\nFender: AgileEon custom aluminum\n\nWeight\nWeight: M - 25.54 kg / 56.3 lbs\nWeight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |\n| N — Frame stack | 62.3 | 65.2 | 68.8 |", - "price": 3499.99, - "tags": [ - "bicycle", - "road bike", - "professional" - ] - }, - { - "name": "Stealth R1X Pro", - "shortDescription": "Stealth R1X Pro is a high-performance carbon road bike designed for riders who crave speed and exceptional handling. With its aerodynamic tube shaping, disc brakes, and lightweight carbon wheels, the Stealth R1X Pro offers unparalleled performance for competitive road cycling.", - "description": "## Overview\nIt's right for you if...\nYou're a competitive cyclist looking for a road bike that offers superior performance in terms of speed, handling, and aerodynamics. You want a complete package that includes lightweight carbon wheels, without the need for future upgrades.\n\nThe tech you get\nThe Stealth R1X Pro features a lightweight and aerodynamic carbon frame, an advanced carbon fork, high-performance Shimano Ultegra 11-speed drivetrain, and powerful Ultegra disc brakes. The bike also comes equipped with cutting-edge Bontrager Aeolus Elite 35 carbon wheels.\n\nThe final word\nThe Stealth R1X Pro stands out with its combination of a fast and aerodynamic frame, high-end drivetrain, and top-of-the-line carbon wheels. Whether you're racing on local roads, participating in pro stage races, or engaging in hill climbing competitions, this bike is a formidable choice that delivers an exceptional riding experience.\n\n## Features\nSleek and aerodynamic design\nThe Stealth R1X Pro's aero tube shapes maximize speed and performance, making it faster on climbs and flats alike. The bike also features a streamlined Aeolus RSL bar/stem for improved front-end aerodynamics.\n\nDesigned for all riders\nThe Stealth R1X Pro is designed to provide an outstanding fit for riders of all genders, body types, riding styles, and abilities. It comes equipped with size-specific components to ensure a comfortable and efficient riding position for competitive riders.\n\n## Specifications\nFrameset\n- Frame: Ultralight carbon frame constructed with high-performance 500 Series ADV Carbon. It features Ride Tuned performance tube optimization, a tapered head tube, internal routing, DuoTrap S compatibility, flat mount disc brake mounts, and a 142x12mm thru axle.\n- Fork: Full carbon fork (Émonda SL) with a tapered carbon steerer, internal brake routing, flat mount disc brake mounts, and a 12x100mm thru axle.\n- Frame fit: H1.5 Race geometry.\n\nWheels\n- Front wheel: Bontrager Aeolus Elite 35 carbon wheel with a 35mm rim depth, ADV Carbon construction, Tubeless Ready compatibility, and a 100x12mm thru axle.\n- Rear wheel: Bontrager Aeolus Elite 35 carbon wheel with a 35mm rim depth, ADV Carbon construction, Tubeless Ready compatibility, Shimano 11/12-speed freehub, and a 142x12mm thru axle.\n- Front skewer: Bontrager Switch thru axle with a removable lever.\n- Rear skewer: Bontrager Switch thru axle with a removable lever.\n- Tire: Bontrager R2 Hard-Case Lite with an aramid bead, 60 tpi, and a size of 700x25c.\n- Maximum tire size: 28mm.\n\nDrivetrain\n- Shifter:\n - Size 47, 50, 52: Shimano Ultegra R8025 with short-reach levers, 11-speed.\n - Size 54, 56, 58, 60, 62: Shimano Ultegra R8020, 11-speed.\n- Front derailleur: Shimano Ultegra R8000, braze-on.\n- Rear derailleur: Shimano Ultegra R8000, short cage, with a maximum cog size of 30T.\n- Crank:\n - Size 47: Shimano Ultegra R8000 with 52/36 chainrings and a 165mm length.\n - Size 50, 52: Shimano Ultegra R8000 with 52/36 chainrings and a 170mm length.\n - Size 54, 56, 58: Shimano Ultegra R8000 with 52/36 chainrings and a 172.5mm length.\n - Size 60, 62: Shimano Ultegra R8000 with 52/36 chainrings and a 175mm length.\n- Bottom bracket: Praxis T47 threaded bottom bracket with internal bearings.\n- Cassette: Shimano Ultegra R8000, 11-30, 11-speed.\n- Chain: Shimano Ultegra HG701, 11-speed.\n- Maximum chainring size: 1x - 50T, 2x - 53/39.\n\nComponents\n- Saddle: Bontrager Aeolus Comp with steel rails and a width of 145mm.\n- Seatpost:\n - Size 47, 50, 52, 54: Bontrager carbon seatmast cap with a 20mm offset and a short length.\n - Size 56, 58, 60, 62: Bontrager carbon seatmast cap with a 20mm offset and a tall length.\n- Handlebar:\n - Size 47, 50: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 38cm.\n - Size 52: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 40cm.\n - Size 54, 56, 58: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 42cm.\n - Size 60, 62: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 44cm.\n- Handlebar tape: Bontrager Supertack Perf tape.\n- Stem:\n - Size 47: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 70mm.\n - Size 50: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 80mm.\n - Size 52, 54: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 90mm.\n - Size 56: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 100mm.\n - Size 58, 60, 62: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 110mm.\n- Brake: Shimano Ultegra hydraulic disc brakes with flat mount calipers.\n- Brake rotor: Shimano RT800 with centerlock mounting, 160mm diameter.\n\nWeight\n- Weight: 8.03 kg (17.71 lbs) for the 56cm frame.\n- Weight limit: The bike has a maximum total weight limit (combined weight of the bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\nPlease refer to the table below for the corresponding Stealth R1X Pro frame sizes, recommended rider height range, and inseam measurements:\n\n| Size | Rider Height | Inseam |\n|:----:|:---------------------:|:--------------:|\n| 47 | 152 - 158 cm (5'0\") | 71 - 75 cm |\n| 50 | 158 - 163 cm (5'2\") | 74 - 77 cm |\n| 52 | 163 - 168 cm (5'4\") | 76 - 79 cm |\n| 54 | 168 - 174 cm (5'6\") | 78 - 82 cm |\n| 56 | 174 - 180 cm (5'9\") | 81 - 85 cm |\n| 58 | 180 - 185 cm (5'11\") | 84 - 87 cm |\n| 60 | 185 - 190 cm (6'1\") | 86 - 90 cm |\n| 62 | 190 - 195 cm (6'3\") | 89 - 92 cm |\n\n## Geometry\nThe table below provides the geometry measurements for each frame size of the Stealth R1X Pro:\n\n| Frame size number | 47 cm | 50 cm | 52 cm | 54 cm | 56 cm | 58 cm | 60 cm | 62 cm |\n|-------------------------------|-------|-------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 42.4 | 45.3 | 48.3 | 49.6 | 52.5 | 55.3 | 57.3 | 59.3 |\n| B — Seat tube angle | 74.6° | 74.6° | 74.2° | 73.7° | 73.3° | 73.0° | 72.8° | 72.5° |\n| C — Head tube length | 10.0 | 11.1 | 12.1 | 13.1 | 15.1 | 17.1 | 19.1 | 21.1 |\n| D — Head angle | 72.1° | 72.1° | 72.8° | 73.0° | 73.5° | 73.8° | 73.9° | 73.9° |\n| E — Effective top tube | 51.2 | 52.1 | 53.4 | 54.3 | 55.9 | 57.4 | 58.6 | 59.8 |\n| G — Bottom bracket drop | 7.2 | 7.2 | 7.2 | 7.0 | 7.0 | 6.8 | 6.8 | 6.8 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 | 41.0 | 41.1 | 41.1 | 41.2 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 | 4.0 | 4.0 | 4.0 | 4.0 |\n| J — Trail | 6.8 | 6.2 | 5.8 | 5.6 | 5.8 | 5.7 | 5.6 | 5.6 |\n| K — Wheelbase | 97.2 | 97.4 | 97.7 | 98.1 | 98.3 | 99.2 | 100.1 | 101.0 |\n| L — Standover | 69.2 | 71.1 | 73.2 | 74.4 | 76.8 | 79.3 | 81.1 | 82.9 |\n| M — Frame reach | 37.3 | 37.8 | 38.3 | 38.6 | 39.1 | 39.6 | 39.9 | 40.3 |\n| N — Frame stack | 50.7 | 52.1 | 53.3 | 54.1 | 56.3 | 58.1 | 60.1 | 62.0 |\n| Saddle rail height min (short mast) | 55.5 | 58.5 | 61.5 | 64.0 | 67.0 | 69.0 | 71.0 | 73.0 |\n| Saddle rail height max (short mast) | 61.5 | 64.5 | 67.5 | 70.0 | 73.0 | 75.0 | 77.0 | 79.0 |\n| Saddle rail height min (tall mast) | 59.0 | 62.0 | 65.0 | 67.5 | 70.5 | 72.5 | 74.5 | 76.5 |\n| Saddle rail height max (tall mast) | 65.0 | 68.0 | 71.0 | 73.5 | 76.5 | 78.5 | 80.5 | 82.5 |", - "price": 2999.99, - "tags": [ - "bicycle", - "mountain bike", - "professional" - ] - }, - { - "name": "Avant SLR 6 Disc Pro", - "shortDescription": "Avant SLR 6 Disc Pro is a high-performance carbon road bike designed for riders who prioritize speed and handling. With its aero tube shaping, disc brakes, and lightweight carbon wheels, it offers the perfect balance of speed and control.", - "description": "## Overview\nIt's right for you if...\nYou're a rider who values exceptional performance on fast group rides and races, and you want a complete package that includes lightweight carbon wheels. The Avant SLR 6 Disc Pro is designed to provide the speed and aerodynamics you need to excel on any road.\n\nThe tech you get\nThe Avant SLR 6 Disc Pro features a lightweight 500 Series ADV Carbon frame and fork, Bontrager Aeolus Elite 35 carbon wheels, a full Shimano Ultegra 11-speed drivetrain, and powerful Ultegra disc brakes.\n\nThe final word\nThe standout feature of this bike is the combination of its aero frame, high-performance drivetrain, and top-quality carbon wheels. Whether you're racing, tackling challenging climbs, or participating in professional stage races, the Avant SLR 6 Disc Pro is a worthy choice that will enhance your performance.\n\n## Features\nAll-new aero design\nThe Avant SLR 6 Disc Pro features innovative aero tube shapes that provide an advantage in all riding conditions, whether it's climbing or riding on flat roads. Additionally, it is equipped with a sleek new Aeolus RSL bar/stem that enhances front-end aero performance.\n\nAwesome bikes for everyone\nThe Avant SLR 6 Disc Pro is designed with the belief that every rider, regardless of gender, body type, riding style, or ability, deserves a great bike. It is equipped with size-specific components that ensure a perfect fit for competitive riders of all genders.\n\n## Specifications\nFrameset\n- Frame: Ultralight 500 Series ADV Carbon, Ride Tuned performance tube optimization, tapered head tube, internal routing, DuoTrap S compatible, flat mount disc, 142x12mm thru axle\n- Fork: Avant SL full carbon, tapered carbon steerer, internal brake routing, flat mount disc, 12x100mm thru axle\n- Frame fit: H1.5 Race\n\nWheels\n- Front wheel: Bontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, 100x12mm thru axle\n- Rear wheel: Bontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, Shimano 11/12-speed freehub, 142x12mm thru axle\n- Front skewer: Bontrager Switch thru axle, removable lever\n- Rear skewer: Bontrager Switch thru axle, removable lever\n- Tire: Bontrager R2 Hard-Case Lite, aramid bead, 60 tpi, 700x25c\n- Max tire size: 28mm\n\nDrivetrain\n- Shifter: \n - Size 47, 50, 52: Shimano Ultegra R8025, short-reach lever, 11-speed\n - Size 54, 56, 58, 60, 62: Shimano Ultegra R8020, 11-speed\n- Front derailleur: Shimano Ultegra R8000, braze-on\n- Rear derailleur: Shimano Ultegra R8000, short cage, 30T max cog\n- Crank: \n - Size 47: Shimano Ultegra R8000, 52/36, 165mm length\n - Size 50, 52: Shimano Ultegra R8000, 52/36, 170mm length\n - Size 54, 56, 58: Shimano Ultegra R8000, 52/36, 172.5mm length\n - Size 60, 62: Shimano Ultegra R8000, 52/36, 175mm length\n- Bottom bracket: Praxis, T47 threaded, internal bearing\n- Cassette: Shimano Ultegra R8000, 11-30, 11-speed\n- Chain: Shimano Ultegra HG701, 11-speed\n- Max chainring size: 1x: 50T, 2x: 53/39\n\nComponents\n- Saddle: Bontrager Aeolus Comp, steel rails, 145mm width\n- Seatpost: \n - Size 47, 50, 52, 54: Bontrager carbon seatmast cap, 20mm offset, short length\n - Size 56, 58, 60, 62: Bontrager carbon seatmast cap, 20mm offset, tall length\n- Handlebar: \n - Size 47, 50: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 38cm width\n - Size 52: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 40cm width\n - Size 54, 56, 58: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 42cm width\n - Size 60, 62: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 44cm width\n- Handlebar tape: Bontrager Supertack Perf tape\n- Stem: \n - Size 47: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 70mm length\n - Size 50: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 80mm length\n - Size 52, 54: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 90mm length\n - Size 56: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 100mm length\n - Size 58, 60, 62: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 110mm length\n- Brake: Shimano Ultegra hydraulic disc, flat mount\n- Brake rotor: Shimano RT800, centerlock, 160mm\n\nWeight\n- Weight: 56 - 8.03 kg / 17.71 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| 47 | 152 - 158 cm 5'0\" - 5'2\" | 71 - 75 cm 28\" - 30\" |\n| 50 | 158 - 163 cm 5'2\" - 5'4\" | 74 - 77 cm 29\" - 30\" |\n| 52 | 163 - 168 cm 5'4\" - 5'6\" | 76 - 79 cm 30\" - 31\" |\n| 54 | 168 - 174 cm 5'6\" - 5'9\" | 78 - 82 cm 31\" - 32\" |\n| 56 | 174 - 180 cm 5'9\" - 5'11\" | 81 - 85 cm 32\" - 33\" |\n| 58 | 180 - 185 cm 5'11\" - 6'1\" | 84 - 87 cm 33\" - 34\" |\n| 60 | 185 - 190 cm 6'1\" - 6'3\" | 86 - 90 cm 34\" - 35\" |\n| 62 | 190 - 195 cm 6'3\" - 6'5\" | 89 - 92 cm 35\" - 36\" |\n\n## Geometry\n| Frame size number | 47 cm | 50 cm | 52 cm | 54 cm | 56 cm | 58 cm | 60 cm | 62 cm |\n|---------------------------------------|-------|-------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 42.4 | 45.3 | 48.3 | 49.6 | 52.5 | 55.3 | 57.3 | 59.3 |\n| B — Seat tube angle | 74.6° | 74.6° | 74.2° | 73.7° | 73.3° | 73.0° | 72.8° | 72.5° |\n| C — Head tube length | 10.0 | 11.1 | 12.1 | 13.1 | 15.1 | 17.1 | 19.1 | 21.1 |\n| D — Head angle | 72.1° | 72.1° | 72.8° | 73.0° | 73.5° | 73.8° | 73.9° | 73.9° |\n| E — Effective top tube | 51.2 | 52.1 | 53.4 | 54.3 | 55.9 | 57.4 | 58.6 | 59.8 |\n| G — Bottom bracket drop | 7.2 | 7.2 | 7.2 | 7.0 | 7.0 | 6.8 | 6.8 | 6.8 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 | 41.0 | 41.1 | 41.1 | 41.2 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 | 4.0 | 4.0 | 4.0 | 4.0 |\n| J — Trail | 6.8 | 6.2 | 5.8 | 5.6 | 5.8 | 5.7 | 5.6 | 5.6 |\n| K — Wheelbase | 97.2 | 97.4 | 97.7 | 98.1 | 98.3 | 99.2 | 100.1 | 101.0 |\n| L — Standover | 69.2 | 71.1 | 73.2 | 74.4 | 76.8 | 79.3 | 81.1 | 82.9 |\n| M — Frame reach | 37.3 | 37.8 | 38.3 | 38.6 | 39.1 | 39.6 | 39.9 | 40.3 |\n| N — Frame stack | 50.7 | 52.1 | 53.3 | 54.1 | 56.3 | 58.1 | 60.1 | 62.0 |\n| Saddle rail height min (w/short mast) | 55.5 | 58.5 | 61.5 | 64.0 | 67.0 | 69.0 | 71.0 | 73.0 |\n| Saddle rail height max (w/short mast) | 61.5 | 64.5 | 67.5 | 70.0 | 73.0 | 75.0 | 77.0 | 79.0 |\n| Saddle rail height min (w/tall mast) | 59.0 | 62.0 | 65.0 | 67.5 | 70.5 | 72.5 | 74.5 | 76.5 |\n| Saddle rail height max (w/tall mast) | 65.0 | 68.0 | 71.0 | 73.5 | 76.5 | 78.5 | 80.5 | 82.5 |", - "price": 999.99, - "tags": [ - "bicycle", - "city bike", - "professional" - ] - } -] \ No newline at end of file + { + "name": "E-Adrenaline 8.0 EX1", + "shortDescription": "a versatile and comfortable e-MTB designed for adrenaline enthusiasts who want to explore all types of terrain. It features a powerful motor and advanced suspension to provide a smooth and responsive ride, with a variety of customizable settings to fit any rider's needs.", + "description": "## Overview\r\nIt's right for you if...\r\nYou want to push your limits on challenging trails and terrain, with the added benefit of an electric assist to help you conquer steep climbs and rough terrain. You also want a bike with a comfortable and customizable fit, loaded with high-quality components and technology.\r\n\r\nThe tech you get\r\nA lightweight, full ADV Mountain Carbon frame with a customizable geometry, including an adjustable head tube and chainstay length. A powerful and efficient motor with a 375Wh battery that can assist up to 28 mph when it's on, and provides a smooth and seamless transition when it's off. A SRAM EX1 8-speed drivetrain, a RockShox Lyrik Ultimate fork, and a RockShox Super Deluxe Ultimate rear shock.\r\n\r\nThe final word\r\nOur E-Adrenaline 8.0 EX1 is the perfect bike for adrenaline enthusiasts who want to explore all types of terrain. It's versatile, comfortable, and loaded with advanced technology to provide a smooth and responsive ride, no matter where your adventures take you.\r\n\r\n\r\n## Features\r\nVersatile and customizable\r\nThe E-Adrenaline 8.0 EX1 features a customizable geometry, including an adjustable head tube and chainstay length, so you can fine-tune your ride to fit your needs and preferences. It also features a variety of customizable settings, including suspension tuning, motor assistance levels, and more.\r\n\r\nPowerful and efficient\r\nThe bike is equipped with a powerful and efficient motor that provides a smooth and seamless transition between human power and electric assist. It can assist up to 28 mph when it's on, and provides zero drag when it's off.\r\n\r\nAdvanced suspension\r\nThe E-Adrenaline 8.0 EX1 features a RockShox Lyrik Ultimate fork and a RockShox Super Deluxe Ultimate rear shock, providing advanced suspension technology to absorb shocks and bumps on any terrain. The suspension is also customizable to fit your riding style and preferences.\r\n\r\n\r\n## Specs\r\nFrameset\r\nFrame ADV Mountain Carbon main frame & stays, adjustable head tube and chainstay length, tapered head tube, Knock Block, Control Freak internal routing, Boost148, 150mm travel\r\nFork RockShox Lyrik Ultimate, DebonAir spring, Charger 2.1 RC2 damper, remote lockout, tapered steerer, 42mm offset, Boost110, 15mm Maxle Stealth, 160mm travel\r\nShock RockShox Super Deluxe Ultimate, DebonAir spring, Thru Shaft 3-position damper, 230x57.5mm\r\n\r\nWheels\r\nWheel front Bontrager Line Elite 30, ADV Mountain Carbon, Tubeless Ready, 6-bolt, Boost110, 15mm thru axle\r\nWheel rear Bontrager Line Elite 30, ADV Mountain Carbon, Tubeless Ready, 54T Rapid Drive, 6-bolt, Shimano MicroSpline freehub, Boost148, 12mm thru axle\r\nSkewer rear Bontrager Switch thru axle, removable lever\r\nTire Bontrager XR5 Team Issue, Tubeless Ready, Inner Strength sidewall, aramid bead, 120tpi, 29x2.50''\r\nTire part Bontrager TLR sealant, 6oz\r\n\r\nDrivetrain\r\nShifter SRAM EX1, 8 speed\r\nRear derailleur SRAM EX1, 8 speed\r\nCrank Bosch Performance CX, magnesium motor body, 250 watt, 75 Nm torque\r\nChainring SRAM EX1, 18T, steel\r\nCassette SRAM EX1, 11-48, 8 speed\r\nChain SRAM EX1, 8 speed\r\n\r\nComponents\r\nSaddle Bontrager Arvada, hollow chromoly rails, 138mm width\r\nSeatpost Bontrager Line Elite Dropper, internal routing, 31.6mm\r\nHandlebar Bontrager Line Pro, ADV Carbon, 35mm, 27.5mm rise, 780mm width\r\nGrips Bontrager XR Trail Elite, alloy lock-on\r\nStem Bontrager Line Pro, 35mm, Knock Block, Blendr compatible, 0 degree, 50mm length\r\nHeadset Knock Block Integrated, 62-degree radius, cartridge bearing, 1-1\/8'' top, 1.5'' bottom\r\nBrake SRAM G2 RSC hydraulic disc, carbon levers\r\nBrake rotor SRAM Centerline, centerlock, round edge, 200mm\r\n\r\nAccessories\r\nE-bike system Bosch Performance CX, magnesium motor body, 250 watt, 75 Nm torque\r\nBattery Bosch PowerTube 625, 625Wh\r\nCharger Bosch 4A standard charger\r\nController Bosch Kiox with Anti-theft solution, Bluetooth connectivity, 1.9'' display\r\nTool Bontrager Switch thru axle, removable lever\r\n\r\nWeight\r\nWeight M - 20.25 kg \/ 44.6 lbs (with TLR sealant, no tubes)\r\nWeight limit This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\r\n\r\n## Sizing & fit\r\n\r\n| Size | Rider Height | Inseam |\r\n|:----:|:------------------------:|:--------------------:|\r\n| S | 155 - 170 cm 5'1\" - 5'7\" | 73 - 80 cm 29\" - 31.5\" |\r\n| M | 163 - 178 cm 5'4\" - 5'10\" | 77 - 83 cm 30.5\" - 32.5\" |\r\n| L | 176 - 191 cm 5'9\" - 6'3\" | 83 - 89 cm 32.5\" - 35\" |\r\n| XL | 188 - 198 cm 6'2\" - 6'6\" | 88 - 93 cm 34.5\" - 36.5\" |\r\n\r\n\r\n## Geometry\r\n\r\nAll measurements provided in cm unless otherwise noted.\r\nSizing table\r\n| Frame size letter | S | M | L | XL |\r\n|---------------------------|-------|-------|-------|-------|\r\n| Actual frame size | 15.8 | 17.8 | 19.8 | 21.8 |\r\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\r\n| A \u2014 Seat tube | 40.0 | 42.5 | 47.5 | 51.0 |\r\n| B \u2014 Seat tube angle | 72.5\u00B0 | 72.8\u00B0 | 73.0\u00B0 | 73.0\u00B0 |\r\n| C \u2014 Head tube length | 9.5 | 10.5 | 11.0 | 11.5 |\r\n| D \u2014 Head angle | 67.8\u00B0 | 67.8\u00B0 | 67.8\u00B0 | 67.8\u00B0 |\r\n| E \u2014 Effective top tube | 59.0 | 62.0 | 65.0 | 68.0 |\r\n| F \u2014 Bottom bracket height | 32.5 | 32.5 | 32.5 | 32.5 |\r\n| G \u2014 Bottom bracket drop | 5.5 | 5.5 | 5.5 | 5.5 |\r\n| H \u2014 Chainstay length | 45.0 | 45.0 | 45.0 | 45.0 |\r\n| I \u2014 Offset | 4.5 | 4.5 | 4.5 | 4.5 |\r\n| J \u2014 Trail | 11.0 | 11.0 | 11.0 | 11.0 |\r\n| K \u2014 Wheelbase | 113.0 | 117.0 | 120.0 | 123.0 |\r\n| L \u2014 Standover | 77.0 | 77.0 | 77.0 | 77.0 |\r\n| M \u2014 Frame reach | 41.0 | 44.5 | 47.5 | 50.0 |\r\n| N \u2014 Frame stack | 61.0 | 62.0 | 62.5 | 63.0 |", + "price": 1499.99, + "tags": [ + "bicycle" + ] + }, + { + "name": "Enduro X Pro", + "shortDescription": "The Enduro X Pro is the ultimate mountain bike for riders who demand the best. With its full carbon frame and top-of-the-line components, this bike is ready to tackle any trail, from technical downhill descents to grueling uphill climbs.", + "text": "## Overview\nIt's right for you if...\nYou're an experienced mountain biker who wants a high-performance bike that can handle any terrain. You want a bike with the best components available, including a full carbon frame, suspension system, and hydraulic disc brakes.\n\nThe tech you get\nOur top-of-the-line full carbon frame with aggressive geometry and a slack head angle for maximum control. It's equipped with a Fox Factory suspension system with 170mm of travel in the front and 160mm in the rear, a Shimano XTR 12-speed drivetrain, and hydraulic disc brakes for maximum stopping power. The bike also features a dropper seatpost for easy adjustments on the fly.\n\nThe final word\nThe Enduro X Pro is the ultimate mountain bike for riders who demand the best. With its full carbon frame, top-of-the-line components, and aggressive geometry, this bike is ready to take on any trail. Whether you're a seasoned pro or just starting out, the Enduro X Pro will help you take your riding to the next level.\n\n## Features\nFull carbon frame\nAggressive geometry with a slack head angle\nFox Factory suspension system with 170mm of travel in the front and 160mm in the rear\nShimano XTR 12-speed drivetrain\nHydraulic disc brakes for maximum stopping power\nDropper seatpost for easy adjustments on the fly\n\n## Specifications\nFrameset\nFrame\tFull carbon frame\nFork\tFox Factory suspension system with 170mm of travel\nRear suspension\tFox Factory suspension system with 160mm of travel\n\nWheels\nWheel size\t27.5\" or 29\"\nTires\tTubeless-ready Maxxis tires\n\nDrivetrain\nShifters\tShimano XTR 12-speed\nFront derailleur\tN/A\nRear derailleur\tShimano XTR\nCrankset\tShimano XTR\nCassette\tShimano XTR 12-speed\nChain\tShimano XTR\n\nComponents\nBrakes\tHydraulic disc brakes\nHandlebar\tAlloy handlebar\nStem\tAlloy stem\nSeatpost\tDropper seatpost\n\nAccessories\nPedals\tNot included\n\nWeight\nWeight\tApproximately 27-29 lbs\n\n## Sizing\n| Size | Rider Height |\n|:----:|:-------------------------:|\n| S | 5'4\" - 5'8\" (162-172cm) |\n| M | 5'8\" - 5'11\" (172-180cm) |\n| L | 5'11\" - 6'3\" (180-191cm) |\n| XL | 6'3\" - 6'6\" (191-198cm) |\n\n## Geometry\n| Size | S | M | L | XL |\n|:----:|:---------------:|:---------------:|:-----------------:|:---------------:|\n| A - Seat tube length | 390mm | 425mm | 460mm | 495mm |\n| B - Effective top tube length | 585mm | 610mm | 635mm | 660mm |\n| C - Head tube angle | 65.5° | 65.5° | 65.5° | 65.5° |\n| D - Seat tube angle | 76° | 76° | 76° | 76° |\n| E - Chainstay length | 435mm | 435mm | 435mm | 435mm |\n| F - Head tube length | 100mm | 110mm | 120mm | 130mm |\n| G - BB drop | 20mm | 20mm | 20mm | 20mm |\n| H - Wheelbase | 1155mm | 1180mm | 1205mm | 1230mm |\n| I - Standover height | 780mm | 800mm | 820mm | 840mm |\n| J - Reach | 425mm | 450mm | 475mm | 500mm |\n| K - Stack | 610mm | 620mm | 630mm | 640mm |", + "price": 599.99, + "tags": [ + "bicycle" + ] + }, + { + "name": "Blaze X1", + "shortDescription": "Blaze X1 is a high-performance road bike that offers superior speed and agility, making it perfect for competitive racing or fast-paced group rides. The bike features a lightweight carbon frame, aerodynamic tube shapes, a 12-speed Shimano Ultegra drivetrain, and hydraulic disc brakes for precise stopping power. With its sleek design and cutting-edge technology, Blaze X1 is a bike that is built to perform and dominate on any road.", + "description": "## Overview\nIt's right for you if...\nYou're a competitive road cyclist or an enthusiast who enjoys fast-paced group rides. You want a bike that is lightweight, agile, and delivers exceptional speed.\n\nThe tech you get\nBlaze X1 features a lightweight carbon frame with a tapered head tube and aerodynamic tube shapes for maximum speed and efficiency. The bike is equipped with a 12-speed Shimano Ultegra drivetrain for smooth and precise shifting, Shimano hydraulic disc brakes for powerful and reliable stopping power, and Bontrager Aeolus Elite 35 carbon wheels for increased speed and agility.\n\nThe final word\nBlaze X1 is a high-performance road bike that is designed to deliver exceptional speed and agility. With its cutting-edge technology and top-of-the-line components, it's a bike that is built to perform and dominate on any road.\n\n## Features\nSpeed and efficiency\nBlaze X1's lightweight carbon frame and aerodynamic tube shapes offer maximum speed and efficiency, allowing you to ride faster and farther with ease.\n\nPrecision stopping power\nShimano hydraulic disc brakes provide precise and reliable stopping power, even in wet or muddy conditions.\n\nAgility and control\nBontrager Aeolus Elite 35 carbon wheels make Blaze X1 incredibly agile and responsive, allowing you to navigate tight turns and corners with ease.\n\nSmooth and precise shifting\nThe 12-speed Shimano Ultegra drivetrain offers smooth and precise shifting, so you can easily find the right gear for any terrain.\n\n## Specifications\nFrameset\nFrame\tADV Carbon, tapered head tube, BB90, direct mount rim brakes, internal cable routing, DuoTrap S compatible, 130x9mm QR\nFork\tADV Carbon, tapered steerer, direct mount rim brakes, internal brake routing, 100x9mm QR\n\nWheels\nWheel front\tBontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, 100x9mm QR\nWheel rear\tBontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, Shimano 11-speed freehub, 130x9mm QR\nTire front\tBontrager R3 Hard-Case Lite, aramid bead, 120 tpi, 700x25c\nTire rear\tBontrager R3 Hard-Case Lite, aramid bead, 120 tpi, 700x25c\nMax tire size\t25c Bontrager tires (with at least 4mm of clearance to frame)\n\nDrivetrain\nShifter\tShimano Ultegra R8020, 12 speed\nFront derailleur\tShimano Ultegra R8000, braze-on\nRear derailleur\tShimano Ultegra R8000, short cage, 30T max cog\nCrank\tSize: 50, 52, 54\nShimano Ultegra R8000, 50/34 (compact), 170mm length\nSize: 56, 58, 60, 62\nShimano Ultegra R8000, 50/34 (compact), 172.5mm length\nBottom bracket\tBB90, Shimano press-fit\nCassette\tShimano Ultegra R8000, 11-30, 12 speed\nChain\tShimano Ultegra HG701, 12 speed\n\nComponents\nSaddle\tBontrager Montrose Elite, titanium rails, 138mm width\nSeatpost\tBontrager carbon seatmast cap, 20mm offset\nHandlebar\tBontrager Elite Aero VR-CF, alloy, 31.8mm, internal cable routing, 40cm width\nGrips\tBontrager Supertack Perf tape\nStem\tBontrager Elite, 31.8mm, Blendr-compatible, 7 degree, 80mm length\nBrake Shimano Ultegra hydraulic disc brake\n\nWeight\nWeight\t56 - 8.91 kg / 19.63 lbs (with tubes)\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\n| Size | Rider height |\n|------|-------------|\n| 50 | 162-166cm |\n| 52 | 165-170cm |\n| 54 | 168-174cm |\n| 56 | 174-180cm |\n| 58 | 179-184cm |\n| 60 | 184-189cm |\n| 62 | 189-196cm |\n\n## Geometry\n| Frame size | 50cm | 52cm | 54cm | 56cm | 58cm | 60cm | 62cm |\n|------------|-------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c | 700c |\n| A - Seat tube | 443mm | 460mm | 478mm | 500mm | 520mm | 540mm | 560mm |\n| B - Seat tube angle | 74.1° | 73.9° | 73.7° | 73.4° | 73.2° | 73.0° | 72.8° |\n| C - Head tube length | 100mm | 110mm | 130mm | 150mm | 170mm | 190mm | 210mm |\n| D - Head angle | 71.4° | 72.0° | 72.5° | 73.0° | 73.3° | 73.6° | 73.8° |\n| E - Effective top tube | 522mm | 535mm | 547mm | 562mm | 577mm | 593mm | 610mm |\n| F - Bottom bracket height | 268mm | 268mm | 268mm | 268mm | 268mm | 268mm | 268mm |\n| G - Bottom bracket drop | 69mm | 69mm | 69mm | 69mm | 69mm | 69mm | 69mm |\n| H - Chainstay length | 410mm | 410mm | 410mm | 410mm | 410mm | 410mm | 410mm |\n| I - Offset | 50mm | 50mm | 50mm | 50mm | 50mm | 50mm | 50mm |\n| J - Trail | 65mm | 62mm | 59mm | 56mm | 55mm | 53mm | 52mm |\n| K - Wheelbase | 983mm | 983mm | 990mm | 1005mm | 1019mm | 1036mm | 1055mm |\n| L - Standover | 741mm | 765mm | 787mm | 806mm | 825mm | 847mm | 869mm |", + "price": 799.99, + "tags": [ + "bicycle", + "mountain bike" + ] + }, + { + "name": "Celerity X5", + "shortDescription": "Celerity X5 is a versatile and reliable road bike that is designed for experienced and amateur riders alike. It's designed to provide smooth and comfortable rides over long distances. With an ultra-lightweight and responsive carbon fiber frame, Shimano 105 groupset, hydraulic disc brakes, and 28mm wide tires, this bike ensures efficient power transfer, precise handling, and superior stopping power.", + "description": "## Overview\n\nIt's right for you if... \nYou are looking for a high-performance road bike that offers a perfect balance of speed, comfort, and control. You enjoy long-distance rides and need a bike that is designed to handle various road conditions with ease. You also appreciate the latest technology and reliable components that make your riding experience more enjoyable.\n\nThe tech you get \nCelerity X5 is equipped with a full carbon fiber frame that ensures maximum strength and durability while keeping the weight down. It features a Shimano 105 groupset with 11-speed gearing for precise and efficient shifting. Hydraulic disc brakes offer superior stopping power, and 28mm wide tires provide comfort and stability on various road surfaces. Internal cable routing enhances the bike's sleek appearance.\n\nThe final word \nIf you are looking for a high-performance road bike that offers comfort, speed, and control, Celerity X5 is the perfect choice. With its lightweight carbon fiber frame, reliable components, and advanced technology, this bike is designed to help you enjoy long-distance rides with ease.\n\n## Features \n\nLightweight and responsive \nCelerity X5 comes with a full carbon fiber frame that is not only lightweight but also responsive, providing excellent handling and control.\n\nHydraulic disc brakes \nThis bike is equipped with hydraulic disc brakes that provide superior stopping power in all weather conditions, ensuring your safety and confidence on the road.\n\nComfortable rides \nThe 28mm wide tires and carbon seat post provide ample cushioning, ensuring a smooth and comfortable ride over long distances.\n\nSleek appearance \nThe bike's internal cable routing enhances its sleek appearance while also protecting the cables from the elements, ensuring smooth shifting for longer periods.\n\n## Specifications \n\nFrameset \nFrame\tCelerity X5 Full Carbon Fiber Frame, Internal Cable Routing, Tapered Headtube, Press Fit Bottom Bracket, 12x142mm Thru-Axle \nFork\tCelerity X5 Full Carbon Fiber Fork, Internal Brake Routing, 12x100mm Thru-Axle \n\nWheels \nWheelset\tAlexRims CXD7 Wheelset \nTire\tSchwalbe Durano Plus 700x28mm \nInner Tubes\tSchwalbe SV15 700x18-28mm \nSkewers\tCelerity X5 Thru-Axle Skewers \n\nDrivetrain \nShifter\tShimano 105 R7025 Hydraulic Disc Shifters \nFront Derailleur\tShimano 105 R7000 \nRear Derailleur\tShimano 105 R7000 \nCrankset\tShimano 105 R7000 50-34T \nBottom Bracket\tShimano BB72-41B \nCassette\tShimano 105 R7000 11-30T \nChain\tShimano HG601 11-Speed Chain \n\nComponents \nSaddle\tSelle Royal Asphalt Saddle \nSeatpost\tCelerity X5 Carbon Seatpost \nHandlebar\tCelerity X5 Compact Handlebar \nStem\tCelerity X5 Aluminum Stem \nHeadset\tFSA Orbit IS-2 \n\nBrakes \nBrakes\tShimano 105 R7025 Hydraulic Disc Brakes \nRotors\tShimano SM-RT70 160mm Rotors \n\nAccessories \nPedals\tCelerity X5 Road Pedals \n\nWeight \nWeight\t8.2 kg / 18.1 lbs \nWeight Limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 120 kg (265 lbs).\n\n## Sizing \n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| 49 | 155 - 162 cm 5'1\" - 5'4\" | 71 - 76 cm 28\" - 30\" |\n| 52 | 162 - 170 cm 5'4\" - 5'7\" | 74 - 79 cm 29\" - 31\" |\n| 54 | 170 - 178 cm 5'7\" - 5'10\" | 77 - 83 cm 30\" - 32\" |\n| 56 | 178 - 185 cm 5'10\" - 6'1\" | 82 - 88 cm 32\" - 34\" |\n| 58 | 185 - 193 cm 6'1\" - 6'4\" | 86 - 92 cm 34\" - 36\" |\n| 61 | 193 - 200 cm 6'4\" - 6'7\" | 90 - 95 cm 35\" - 37\" |\n\n## Geometry \n| Frame size number | 49 cm | 52 cm | 54 cm | 56 cm | 58 cm | 61 cm |\n|---------------------------------------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 47.5 | 50.0 | 52.0 | 54.0 | 56.0 | 58.5 |\n| B — Seat tube angle | 75.0° | 74.5° | 74.0° | 73.5° | 73.0° | 72.5° |\n| C — Head tube length | 12.0 | 14.5 | 16.5 | 18.5 | 20.5 | 23.5 |\n| D — Head angle | 70.0° | 71.0° | 71.5° | 72.0° | 72.5° | 72.5° |\n| E — Effective top tube | 52.5 | 53.5 | 54.5 | 56.0 | 57.5 | 59.5 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 |\n| K — Wheelbase | 98.4 | 98.9 | 99.8 | 100.8 | 101.7 | 103.6 |\n| L — Standover | 72.0 | 74.0 | 76.0 | 78.0 | 80.0 | 82.0 |\n| M — Frame reach | 36.2 | 36.8 | 37.3 | 38.1 | 38.6 | 39.4 |\n| N — Frame stack | 52.0 | 54.3 | 56.2 | 58.1 | 59.8 | 62.4 |\n| Saddle rail height min | 67.0 | 69.5 | 71.5 | 74.0 | 76.0 | 78.0 |\n| Saddle rail height max | 75.0 | 77.5 | 79.5 | 82.0 | 84.0 | 86.0 |", + "price": 399.99, + "tags": [ + "bicycle", + "city bike" + ] + }, + { + "name": "Velocity V8", + "shortDescription": "Velocity V8 is a high-performance road bike that is designed to deliver speed, agility, and control on the road. With its lightweight aluminum frame, carbon fiber fork, Shimano Tiagra groupset, and hydraulic disc brakes, this bike is perfect for experienced riders who are looking for a fast and responsive bike that can handle various road conditions.", + "description": "## Overview\n\nIt's right for you if... \nYou are an experienced rider who is looking for a high-performance road bike that is lightweight, agile, and responsive. You want a bike that can handle long-distance rides, steep climbs, and fast descents with ease. You also appreciate the latest technology and reliable components that make your riding experience more enjoyable.\n\nThe tech you get \nVelocity V8 features a lightweight aluminum frame with a carbon fiber fork that ensures a comfortable ride without sacrificing stiffness and power transfer. It comes with a Shimano Tiagra groupset with 10-speed gearing for precise and efficient shifting. Hydraulic disc brakes offer superior stopping power in all weather conditions, while 28mm wide tires provide comfort and stability on various road surfaces. Internal cable routing enhances the bike's sleek appearance.\n\nThe final word \nIf you are looking for a high-performance road bike that is lightweight, fast, and responsive, Velocity V8 is the perfect choice. With its lightweight aluminum frame, reliable components, and advanced technology, this bike is designed to help you enjoy fast and comfortable rides on the road.\n\n## Features \n\nLightweight and responsive \nVelocity V8 comes with a lightweight aluminum frame that is not only lightweight but also responsive, providing excellent handling and control.\n\nHydraulic disc brakes \nThis bike is equipped with hydraulic disc brakes that provide superior stopping power in all weather conditions, ensuring your safety and confidence on the road.\n\nComfortable rides \nThe 28mm wide tires and carbon fork provide ample cushioning, ensuring a smooth and comfortable ride over long distances.\n\nSleek appearance \nThe bike's internal cable routing enhances its sleek appearance while also protecting the cables from the elements, ensuring smooth shifting for longer periods.\n\n## Specifications \n\nFrameset \nFrame\tVelocity V8 Aluminum Frame, Internal Cable Routing, Tapered Headtube, Press Fit Bottom Bracket, 12x142mm Thru-Axle \nFork\tVelocity V8 Carbon Fiber Fork, Internal Brake Routing, 12x100mm Thru-Axle \n\nWheels \nWheelset\tAlexRims CXD7 Wheelset \nTire\tSchwalbe Durano Plus 700x28mm \nInner Tubes\tSchwalbe SV15 700x18-28mm \nSkewers\tVelocity V8 Thru-Axle Skewers \n\nDrivetrain \nShifter\tShimano Tiagra Hydraulic Disc Shifters \nFront Derailleur\tShimano Tiagra \nRear Derailleur\tShimano Tiagra \nCrankset\tShimano Tiagra 50-34T \nBottom Bracket\tShimano BB-RS500-PB \nCassette\tShimano Tiagra 11-32T \nChain\tShimano HG54 10-Speed Chain \n\nComponents \nSaddle\tVelocity V8 Saddle \nSeatpost\tVelocity V8 Aluminum Seatpost \nHandlebar\tVelocity V8 Compact Handlebar \nStem\tVelocity V8 Aluminum Stem \nHeadset\tFSA Orbit IS-2 \n\nBrakes \nBrakes\tShimano Tiagra Hydraulic Disc Brakes \nRotors\tShimano SM-RT64 160mm Rotors \n\nAccessories \nPedals\tVelocity V8 Road Pedals \n\nWeight \nWeight\t9.4 kg / 20.7 lbs \nWeight Limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 120 kg (265 lbs).\n\n## Sizing \n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| 49 | 155 - 162 cm 5'1\" - 5'4\" | 71 - 76 cm 28\" - 30\" |\n| 52 | 162 - 170 cm 5'4\" - 5'7\" | 74 - 79 cm 29\" - 31\" |\n| 54 | 170 - 178 cm 5'7\" - 5'10\" | 77 - 83 cm 30\" - 32\" |\n| 56 | 178 - 185 cm 5'10\" - 6'1\" | 82 - 88 cm 32\" - 34\" |\n| 58 | 185 - 193 cm 6'1\" - 6'4\" | 86 - 92 cm 34\" - 36\" |\n| 61 | 193 - 200 cm 6'4\" - 6'7\" | 90 - 95 cm 35\" - 37\" |\n\n## Geometry \n| Frame size number | 49 cm | 52 cm | 54 cm | 56 cm | 58 cm | 61 cm |\n|---------------------------------------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 47.5 | 50.0 | 52.0 | 54.0 | 56.0 | 58.5 |\n| B — Seat tube angle | 75.0° | 74.5° | 74.0° | 73.5° | 73.0° | 72.5° |\n| C — Head tube length | 12.0 | 14.5 | 16.5 | 18.5 | 20.5 | 23.5 |\n| D — Head angle | 70.0° | 71.0° | 71.5° | 72.0° | 72.5° | 72.5° |\n| E — Effective top tube | 52.5 | 53.5 | 54.5 | 56.0 | 57.5 | 59.5 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 | 41.5 |\n| K — Wheelbase | 98.4 | 98.9 | 99.8 | 100.8 | 101.7 | 103.6 |\n| L — Standover | 72.0 | 74.0 | 76.0 | 78.0 | 80.0 | 82.0 |\n| M — Frame reach | 36.2 | 36.8 | 37.3 | 38.1 | 38.6 | 39.4 |\n| N — Frame stack | 52.0 | 54.3 | 56.2 | 58.1 | 59.8 | 62.4 |\n| Saddle rail height min | 67.0 | 69.5 | 71.5 | 74.0 | 76.0 | 78.0 |\n| Saddle rail height max | 75.0 | 77.5 | 79.5 | 82.0 | 84.0 | 86.0 |", + "price": 1899.99, + "tags": [ + "bicycle", + "electric bike" + ] + }, + { + "name": "VeloCore X9 eMTB", + "shortDescription": "The VeloCore X9 eMTB is a light, agile and versatile electric mountain bike designed for adventure and performance. Its purpose-built frame and premium components offer an exhilarating ride experience on both technical terrain and smooth singletrack.", + "description": "## Overview\nIt's right for you if...\nYou love exploring new trails and testing your limits on challenging terrain. You want an electric mountain bike that offers power when you need it, without sacrificing performance or agility. You're looking for a high-quality bike with top-notch components and a sleek design.\n\nThe tech you get\nA lightweight, full carbon frame with custom geometry, a 140mm RockShox Pike Ultimate fork with Charger 2.1 damper, and a Fox Float DPS Performance shock. A Shimano STEPS E8000 motor and 504Wh battery that provide up to 62 miles of range and 20 mph assistance. A Shimano XT 12-speed drivetrain, Shimano SLX brakes, and DT Swiss wheels.\n\nThe final word\nThe VeloCore X9 eMTB delivers power and agility in equal measure. It's a versatile and capable electric mountain bike that can handle any trail with ease. With premium components, a custom carbon frame, and a sleek design, this bike is built for adventure.\n\n## Features\nAgile and responsive\n\nThe VeloCore X9 eMTB is designed to be nimble and responsive on the trail. Its custom carbon frame offers a perfect balance of stiffness and compliance, while the suspension system provides smooth and stable performance on technical terrain.\n\nPowerful and efficient\n\nThe Shimano STEPS E8000 motor and 504Wh battery provide up to 62 miles of range and 20 mph assistance. The motor delivers smooth and powerful performance, while the battery offers reliable and consistent power for long rides.\n\nCustomizable ride experience\n\nThe VeloCore X9 eMTB comes with an intuitive and customizable Shimano STEPS display that allows you to adjust the level of assistance, monitor your speed and battery life, and customize your ride experience to suit your needs.\n\nPremium components\n\nThe VeloCore X9 eMTB is equipped with high-end components, including a Shimano XT 12-speed drivetrain, Shimano SLX brakes, and DT Swiss wheels. These components offer reliable and precise performance, allowing you to push your limits with confidence.\n\n## Specs\nFrameset\nFrame\tVeloCore carbon fiber frame, Boost, tapered head tube, internal cable routing, 140mm travel\nFork\tRockShox Pike Ultimate, Charger 2.1 damper, DebonAir spring, 15x110mm Boost Maxle Ultimate, 46mm offset, 140mm travel\nShock\tFox Float DPS Performance, EVOL, 3-position adjust, Kashima Coat, 210x50mm\n\nWheels\nWheel front\tDT Swiss XM1700 Spline, 30mm internal width, 15x110mm Boost axle\nWheel rear\tDT Swiss XM1700 Spline, 30mm internal width, Shimano Microspline driver, 12x148mm Boost axle\nTire front\tMaxxis Minion DHF, 29x2.5\", EXO+ casing, tubeless ready\nTire rear\tMaxxis Minion DHR II, 29x2.4\", EXO+ casing, tubeless ready\n\nDrivetrain\nShifter\tShimano XT M8100, 12-speed\nRear derailleur\tShimano XT M8100, Shadow Plus, long cage, 51T max cog\nCrankset\tShimano STEPS E8000, 165mm length, 34T chainring\nCassette\tShimano XT M8100, 10-51T, 12-speed\nChain\tShimano CN-M8100, 12-speed\nPedals\tNot included\n\nComponents\nSaddle\tBontrager Arvada, hollow chromoly rails\nSeatpost\tDrop Line, internal routing, 31.6mm (15.5: 100mm, 17.5 & 18.5: 125mm, 19.5 & 21.5: 150mm)\nHandlebar\tBontrager Line Pro, ADV Carbon, 35mm, 27.5mm rise, 780mm width\nStem\tBontrager Line Pro, 35mm, Knock Block, 0 degree, 50mm length\nGrips\tBontrager XR Trail Elite, alloy lock-on\nHeadset\tIntegrated, sealed cartridge bearing, 1-1/8\" top, 1.5\" bottom\nBrakeset\tShimano SLX M7120, 4-piston hydraulic disc\n\nAccessories\nBattery\tShimano STEPS BT-E8010, 504Wh\nCharger\tShimano STEPS EC-E8004, 4A\nController\tShimano STEPS E8000 display\nBike weight\tM - 22.5 kg / 49.6 lbs (with tubes)\n\n## Sizing & fit\n\n| Size | Rider Height |\n|:----:|:------------------------:|\n| S | 162 - 170 cm 5'4\" - 5'7\" |\n| M | 170 - 178 cm 5'7\" - 5'10\"|\n| L | 178 - 186 cm 5'10\" - 6'1\"|\n| XL | 186 - 196 cm 6'1\" - 6'5\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\n| Frame size | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| A — Seat tube | 40.6 | 43.2 | 47.0 | 51.0 |\n| B — Seat tube angle | 75.0° | 75.0° | 75.0° | 75.0° |\n| C — Head tube length | 9.6 | 10.6 | 11.6 | 12.6 |\n| D — Head angle | 66.5° | 66.5° | 66.5° | 66.5° |\n| E — Effective top tube | 60.4 | 62.6 | 64.8 | 66.9 |\n| F — Bottom bracket height | 33.2 | 33.2 | 33.2 | 33.2 |\n| G — Bottom bracket drop | 3.0 | 3.0 | 3.0 | 3.0 |\n| H — Chainstay length | 45.5 | 45.5 | 45.5 | 45.5 |\n| I — Offset | 4.6 | 4.6 | 4.6 | 4.6 |\n| J — Trail | 11.9 | 11.9 | 11.9 | 11.9 |\n| K — Wheelbase | 117.0 | 119.3 | 121.6 | 123.9 |\n| L — Standover | 75.9 | 75.9 | 78.6 | 78.6 |\n| M — Frame reach | 43.6 | 45.6 | 47.6 | 49.6 |\n| N — Frame stack | 60.5 | 61.5 | 62.4 | 63.4 |", + "price": 1299.99, + "tags": [ + "bicycle", + "touring bike" + ] + }, + { + "name": "Zephyr 8.8 GX Eagle AXS Gen 3", + "shortDescription": "Zephyr 8.8 GX Eagle AXS is a light and nimble full-suspension mountain bike. It's designed to handle technical terrain with ease and has a smooth and efficient ride feel. The sleek and powerful Bosch Performance Line CX motor and removable Powertube battery provide a boost to your pedaling and give you long-lasting riding time. The bike also features high-end components and advanced technology for an ultimate mountain biking experience.", + "description": "## Overview\nIt's right for you if...\nYou're an avid mountain biker looking for a high-performance e-MTB that can tackle challenging trails. You want a bike with a powerful motor, efficient suspension, and advanced technology to enhance your riding experience. You also need a bike that's reliable and durable for long-lasting use.\n\nThe tech you get\nA lightweight, full carbon frame with 150mm of rear travel and a 160mm RockShox Pike Ultimate fork with Charger 2.1 RCT3 damper, remote lockout, and DebonAir spring. A Bosch Performance Line CX motor and removable Powertube 625Wh battery that can assist up to 20mph when it's on and gives zero drag when it's off, plus an easy-to-use handlebar-mounted Bosch Purion controller. A SRAM GX Eagle AXS wireless electronic drivetrain, a RockShox Reverb Stealth dropper, and DT Swiss HX1501 Spline One wheels.\n\nThe final word\nZephyr 8.8 GX Eagle AXS is a high-performance e-MTB that's designed to handle technical terrain with ease. With a powerful Bosch motor and long-lasting battery, you can conquer challenging climbs and enjoy long rides. The bike also features high-end components and advanced technology for an ultimate mountain biking experience.\n\n## Features\nPowerful motor\n\nThe Bosch Performance Line CX motor provides a boost to your pedaling and can assist up to 20mph. It has four power modes and a walk-assist function for easy navigation on steep climbs. The motor is also reliable and durable for long-lasting use.\n\nEfficient suspension\n\nZephyr 8.8 has a 150mm of rear travel and a 160mm RockShox Pike Ultimate fork with Charger 2.1 RCT3 damper, remote lockout, and DebonAir spring. The suspension is efficient and responsive, allowing you to handle technical terrain with ease.\n\nRemovable battery\n\nThe Powertube 625Wh battery is removable for easy charging and storage. It provides long-lasting riding time and can be replaced with a spare battery for even longer rides. The battery is also durable and weather-resistant for all-season riding.\n\nAdvanced technology\n\nZephyr 8.8 is equipped with advanced technology, including a Bosch Purion controller for easy motor control, a SRAM GX Eagle AXS wireless electronic drivetrain for precise shifting, and a RockShox Reverb Stealth dropper for adjustable saddle height. The bike also has DT Swiss HX1501 Spline One wheels for reliable performance on any terrain.\n\nCarbon frame\n\nThe full carbon frame is lightweight and durable, providing a smooth and efficient ride. It's also designed with a tapered head tube, internal cable routing, and Boost148 spacing for enhanced stiffness and responsiveness.\n\n## Specs\nFrameset\nFrame\tCarbon main frame & stays, tapered head tube, internal routing, Boost148, 150mm travel\nFork\tRockShox Pike Ultimate, Charger 2.1 RCT3 damper, DebonAir spring, remote lockout, tapered steerer, Boost110, 15mm Maxle Stealth, 160mm travel\nShock\tRockShox Deluxe RT3, DebonAir spring, 205mm x 57.5mm\nMax compatible fork travel\t170mm\n\nWheels\nWheel front\tDT Swiss HX1501 Spline One, Centerlock, 30mm inner width, 110x15mm Boost\nWheel rear\tDT Swiss HX1501 Spline One, Centerlock, 30mm inner width, SRAM XD driver, 148x12mm Boost\nTire\tBontrager XR4 Team Issue, Tubeless Ready, Inner Strength sidewall, aramid bead, 120tpi, 29x2.40''\nMax tire size\t29x2.60\"\n\nDrivetrain\nShifter\tSRAM GX Eagle AXS, wireless, 12 speed\nRear derailleur\tSRAM GX Eagle AXS\nCrank\tBosch Gen 4, 32T\nChainring\tSRAM X-Sync 2, 32T, direct-mount\nCassette\tSRAM PG-1275 Eagle, 10-52, 12 speed\nChain\tSRAM GX Eagle, 12 speed\n\nComponents\nSaddle\tBontrager Arvada, hollow titanium rails, 138mm width\nSeatpost\tRockShox Reverb Stealth, 31.6mm, internal routing, 150mm (S), 170mm (M/L), 200mm (XL)\nHandlebar\tBontrager Line Pro, ADV Carbon, 35mm, 27.5mm rise, 780mm width\nGrips\tBontrager XR Trail Elite, alloy lock-on\nStem\tBontrager Line Pro, Knock Block, 35mm, 0 degree, 50mm length\nHeadset\tIntegrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\nBrake\tSRAM Code RSC hydraulic disc, 200mm (front), 180mm (rear)\nBrake rotor\tSRAM CenterLine, centerlock, round edge, 200mm (front), 180mm (rear)\n\nAccessories\nE-bike system\tBosch Performance Line CX\nBattery\tBosch Powertube 625Wh\nCharger\tBosch 4A compact charger\nController\tBosch Purion\nTool\tBontrager multi-tool, integrated storage bag\n\nWeight\nWeight\tM - 24.08 kg / 53.07 lbs (with TLR sealant, no tubes)\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n\n## Sizing & fit\n\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| S | 153 - 162 cm 5'0\" - 5'4\" | 67 - 74 cm 26\" - 29\" |\n| M | 161 - 172 cm 5'3\" - 5'8\" | 74 - 79 cm 29\" - 31\" |\n| L | 171 - 180 cm 5'7\" - 5'11\" | 79 - 84 cm 31\" - 33\" |\n| XL | 179 - 188 cm 5'10\" - 6'2\" | 84 - 89 cm 33\" - 35\" |\n\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Actual frame size | 15.5 | 17.5 | 19.5 | 21.5 |\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\n| A — Seat tube | 39.4 | 41.9 | 44.5 | 47.6 |\n| B — Seat tube angle | 76.1° | 76.1° | 76.1° | 76.1° |\n| C — Head tube length | 9.6 | 10.5 | 11.5 | 12.5 |\n| D — Head angle | 65.5° | 65.5° | 65.5° | 65.5° |\n| E — Effective top tube | 58.6 | 61.3 | 64.0 | 66.7 |\n| F — Bottom bracket height | 34.0 | 34.0 | 34.0 | 34.0 |\n| G — Bottom bracket drop | 1.0 | 1.0 | 1.0 | 1.0 |\n| H — Chainstay length | 45.0 | 45.0 | 45.0 | 45.0 |\n| I — Offset | 4.6 | 4.6 | 4.6 | 4.6 |\n| J — Trail | 10.5 | 10.5 | 10.5 | 10.5 |\n| K — Wheelbase | 119.5 | 122.3 | 125.0 | 127.8 |\n| L — Standover | 72.7 | 74.7 | 77.6 | 81.0 |\n|", + "price": 1499.99, + "tags": [ + "bicycle", + "electric bike", + "city bike" + ] + }, + { + "name": "Velo 99 XR1 AXS", + "shortDescription": "Velo 99 XR1 AXS is a next-generation bike designed for fast-paced adventure seekers and speed enthusiasts. Built for high-performance racing, the bike boasts state-of-the-art technology and premium components. It is the ultimate bike for riders who want to push their limits and get their adrenaline pumping.", + "description": "## Overview\nIt's right for you if...\nYou are a passionate cyclist looking for a bike that can keep up with your speed, agility, and endurance. You are an adventurer who loves to explore new terrains and challenge yourself on the toughest courses. You want a bike that is lightweight, durable, and packed with the latest technology.\n\nThe tech you get\nA lightweight, full carbon frame with advanced aerodynamics and integrated cable routing for a clean look. A high-performance SRAM XX1 Eagle AXS wireless electronic drivetrain, featuring a 12-speed cassette and a 32T chainring. A RockShox SID Ultimate fork with a remote lockout, 120mm travel, and Charger Race Day damper. A high-end SRAM G2 Ultimate hydraulic disc brake with carbon levers. A FOX Transfer SL dropper post for quick and easy height adjustments. DT Swiss XRC 1501 carbon wheels for superior speed and handling.\n\nThe final word\nVelo 99 XR1 AXS is a premium racing bike that can help you achieve your goals and reach new heights. It is designed for speed, agility, and performance, and it is packed with the latest technology and premium components. If you are a serious cyclist who wants the best, this is the bike for you.\n\n## Features\nAerodynamic design\n\nThe Velo 99 XR1 AXS features a state-of-the-art frame design that reduces drag and improves speed. It has an aerodynamic seatpost, integrated cable routing, and a sleek, streamlined look that sets it apart from other bikes.\n\nWireless electronic drivetrain\n\nThe SRAM XX1 Eagle AXS drivetrain features a wireless electronic system that provides precise, instant shifting and unmatched efficiency. It eliminates the need for cables and makes the bike lighter and faster.\n\nHigh-performance suspension\n\nThe RockShox SID Ultimate fork and Charger Race Day damper provide 120mm of smooth, responsive suspension that can handle any terrain. The fork also has a remote lockout for quick adjustments on the fly.\n\nSuperior braking power\n\nThe SRAM G2 Ultimate hydraulic disc brake system delivers unmatched stopping power and control. It has carbon levers for a lightweight, ergonomic design and precision control.\n\nCarbon wheels\n\nThe DT Swiss XRC 1501 carbon wheels are ultra-lightweight, yet incredibly strong and durable. They provide superior speed and handling, making the bike more agile and responsive.\n\n## Specs\nFrameset\nFrame\tFull carbon frame, integrated cable routing, aerodynamic design, Boost148\nFork\tRockShox SID Ultimate, Charger Race Day damper, remote lockout, tapered steerer, Boost110, 15mm Maxle Stealth, 120mm travel\n\nWheels\nWheel front\tDT Swiss XRC 1501 carbon wheel, Boost110, 15mm thru axle\nWheel rear\tDT Swiss XRC 1501 carbon wheel, SRAM XD driver, Boost148, 12mm thru axle\nTire\tSchwalbe Racing Ray, Performance Line, Addix, 29x2.25\"\nTire part\tSchwalbe Doc Blue Professional, 500ml\nMax tire size\t29x2.3\"\n\nDrivetrain\nShifter\tSRAM Eagle AXS, wireless, 12-speed\nRear derailleur\tSRAM XX1 Eagle AXS\nCrank\tSRAM XX1 Eagle, 32T, carbon\nChainring\tSRAM X-SYNC, 32T, alloy\nCassette\tSRAM Eagle XG-1299, 10-52, 12-speed\nChain\tSRAM XX1 Eagle, 12-speed\nMax chainring size\t1x: 32T\n\nComponents\nSaddle\tBontrager Montrose Elite, carbon rails, 138mm width\nSeatpost\tFOX Transfer SL, 125mm travel, internal routing, 31.6mm\nHandlebar\tBontrager Kovee Pro, ADV Carbon, 35mm, 5mm rise, 720mm width\nGrips\tBontrager XR Endurance Elite\nStem\tBontrager Kovee Pro, 35mm, Blendr compatible, 7 degree, 60mm length\nHeadset\tIntegrated, cartridge bearing, 1-1/8\" top, 1.5\" bottom\nBrake\tSRAM G2 Ultimate hydraulic disc, carbon levers, 180mm rotors\n\nAccessories\nBike computer\tBontrager Trip 300\nTool\tBontrager Flatline Pro pedal wrench, T25 Torx\n\n\n## Sizing & fit\n\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| S | 158 - 168 cm 5'2\" - 5'6\" | 74 - 78 cm 29\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 78 - 82 cm 31\" - 32\" |\n| L | 173 - 183 cm 5'8\" - 6'0\" | 82 - 86 cm 32\" - 34\" |\n| XL | 180 - 193 cm 5'11\" - 6'4\" | 86 - 90 cm 34\" - 35\" |\n\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Actual frame size | 15.5 | 17.5 | 19.5 | 21.5 |\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\n| A — Seat tube | 39.9 | 43.0 | 47.0 | 51.0 |\n| B — Seat tube angle | 74.5° | 74.5° | 74.5° | 74.5° |\n| C — Head tube length | 9.0 | 10.0 | 11.0 | 12.0 |\n| D — Head angle | 68.0° | 68.0° | 68.0° | 68.0° |\n| E — Effective top tube | 57.8 | 59.7 | 61.6 | 63.6 |\n| F — Bottom bracket height | 33.0 | 33.0 | 33.0 | 33.0 |\n| G — Bottom bracket drop | 5.0 | 5.0 | 5.0 | 5.0 |\n| H — Chainstay length | 43.0 | 43.0 | 43.0 | 43.0 |\n| I — Offset | 4.2 | 4.2 | 4.2 | 4.2 |\n| J — Trail | 9.7 | 9.7 | 9.7 | 9.7 |\n| K — Wheelbase | 112.5 | 114.5 | 116.5 | 118.6 |\n| L — Standover | 75.9 | 77.8 | 81.5 | 84.2 |\n| M — Frame reach | 41.6 | 43.4 | 45.2 | 47.1 |\n| N — Frame stack | 58.2 | 58.9 | 59.3 | 59.9 |", + "price": 1099.99, + "tags": [ + "bicycle", + "mountain bike" + ] + }, + { + "name": "AURORA 11S E-MTB", + "shortDescription": "The AURORA 11S is a powerful and stylish electric mountain bike designed to take you on thrilling off-road adventures. With its sturdy frame and premium components, this bike is built to handle any terrain. It features a high-performance motor, long-lasting battery, and advanced suspension system that guarantee a smooth and comfortable ride.", + "description": "## Overview\nIt's right for you if...\nYou want a top-of-the-line e-MTB that is both powerful and stylish. You also want a bike that can handle any terrain, from steep climbs to rocky descents. With its advanced features and premium components, the AURORA 11S is designed for serious off-road riders who demand the best.\n\nThe tech you get\nA sturdy aluminum frame with advanced suspension system that provides 120mm of travel. A 750W brushless motor that delivers up to 28mph, and a 48V/14Ah lithium-ion battery that provides up to 60 miles of range on a single charge. An advanced 11-speed Shimano drivetrain with hydraulic disc brakes for precise shifting and reliable stopping power. \n\nThe final word\nThe AURORA 11S is a top-of-the-line e-MTB that delivers exceptional performance and style. Whether you're tackling steep climbs or hitting rocky descents, this bike is built to handle any terrain with ease. With its advanced features and premium components, the AURORA 11S is the perfect choice for serious off-road riders who demand the best.\n\n## Features\nPowerful and efficient\n\nThe AURORA 11S is equipped with a high-performance 750W brushless motor that delivers up to 28mph. The motor is powered by a long-lasting 48V/14Ah lithium-ion battery that provides up to 60 miles of range on a single charge.\n\nAdvanced suspension system\n\nThe bike's advanced suspension system provides 120mm of travel, ensuring a smooth and comfortable ride on any terrain. The front suspension is a Suntour XCR32 Air fork, while the rear suspension is a KS-281 hydraulic shock absorber.\n\nPremium components\n\nThe AURORA 11S features an advanced 11-speed Shimano drivetrain with hydraulic disc brakes. The bike is also equipped with a Tektro HD-E725 hydraulic disc brake system that provides reliable stopping power.\n\nSleek and stylish design\n\nWith its sleek and stylish design, the AURORA 11S is sure to turn heads on the trail. The bike's sturdy aluminum frame is available in a range of colors, including black, blue, and red.\n\n## Specs\nFrameset\nFrame Material: Aluminum\nFrame Size: S, M, L\nFork: Suntour XCR32 Air, 120mm Travel\nShock Absorber: KS-281 Hydraulic Shock Absorber\n\nWheels\nWheel Size: 27.5 inches\nTires: Kenda K1151 Nevegal, 27.5x2.35\nRims: Alloy Double Wall\nSpokes: 32H, Stainless Steel\n\nDrivetrain\nShifters: Shimano SL-M7000\nRear Derailleur: Shimano RD-M8000\nCrankset: Prowheel 42T, Alloy Crank Arm\nCassette: Shimano CS-M7000, 11-42T\nChain: KMC X11EPT\n\nBrakes\nBrake System: Tektro HD-E725 Hydraulic Disc Brake\nBrake Rotors: 180mm Front, 160mm Rear\n\nE-bike system\nMotor: 750W Brushless\nBattery: 48V/14Ah Lithium-Ion\nCharger: 48V/3A Smart Charger\nController: Intelligent Sinusoidal Wave\n\nWeight\nWeight: 59.5 lbs\n\n## Sizing & fit\n| Size | Rider Height | Standover Height |\n|------|-------------|-----------------|\n| S | 5'2\"-5'6\" | 28.5\" |\n| M | 5'7\"-6'0\" | 29.5\" |\n| L | 6'0\"-6'4\" | 30.5\" |\n\n## Geometry\nAll measurements provided in cm.\nSizing table\n| Frame size letter | S | M | L |\n|-------------------|-----|-----|-----|\n| Wheel Size | 27.5\"| 27.5\"| 27.5\"|\n| Seat tube length | 44.5| 48.5| 52.5|\n| Head tube angle | 68° | 68° | 68° |\n| Seat tube angle | 74.5°| 74.5°| 74.5°|\n| Effective top tube | 57.5| 59.5| 61.5|\n| Head tube length | 12.0| 12.0| 13.0|\n| Chainstay length | 45.5| 45.5| 45.5|\n| Bottom bracket height | 30.0| 30.0| 30.0|\n| Wheelbase | 115.0|116.5|118.5|", + "price": 1999.99, + "tags": [ + "bicycle", + "road bike" + ] + }, + { + "name": "VeloTech V9.5 AXS Gen 3", + "shortDescription": "VeloTech V9.5 AXS is a sleek and fast carbon bike that combines high-end tech with a comfortable ride. It's designed to provide the ultimate experience for the most serious riders. The bike comes with a lightweight and powerful motor that can be activated when needed, and you get a spec filled with premium parts.", + "description": "## Overview\nIt's right for you if...\nYou want a bike that is fast, efficient, and delivers an adrenaline-filled experience. You are looking for a bike that is built with cutting-edge technology, and you want a ride that is both comfortable and exciting.\n\nThe tech you get\nA lightweight and durable full carbon frame with a fork that has 100mm of travel. The bike comes with a powerful motor that can deliver up to 20 mph of assistance. The drivetrain is a wireless electronic system that is precise and reliable. The bike is also equipped with hydraulic disc brakes, tubeless-ready wheels, and comfortable grips.\n\nThe final word\nThe VeloTech V9.5 AXS is a high-end bike that delivers an incredible experience for serious riders. It combines the latest technology with a comfortable ride, making it perfect for long rides, tough climbs, and fast descents.\n\n## Features\nFast and efficient\nThe VeloTech V9.5 AXS comes with a powerful motor that can provide up to 20 mph of assistance. The motor is lightweight and efficient, providing a boost when you need it without adding bulk. The bike's battery is removable, allowing you to ride without assistance when you don't need it.\n\nSmart software for the trail\nThe VeloTech V9.5 AXS is equipped with intelligent software that delivers a smooth and responsive ride. The software allows the motor to respond immediately as you start to pedal, delivering more power over a wider cadence range. You can also customize your user settings to suit your preferences.\n\nComfortable ride\nThe VeloTech V9.5 AXS is designed to provide a comfortable ride, even on long rides. The bike's fork has 100mm of travel, providing ample cushioning for rough terrain. The bike's grips are also designed to provide a comfortable and secure grip, even on the most challenging rides.\n\n## Specs\nFrameset\nFrame\tCarbon fiber frame with internal cable routing and Boost148\nFork\t100mm of travel with remote lockout\nShock\tN/A\n\nWheels\nWheel front\tCarbon fiber tubeless-ready wheel\nWheel rear\tCarbon fiber tubeless-ready wheel\nSkewer rear\t12mm thru-axle\nTire\tTubeless-ready tire\nTire part\tTubeless sealant\n\nDrivetrain\nShifter\tWireless electronic shifter\nRear derailleur\tWireless electronic derailleur\nCrank\tCarbon fiber crankset with chainring\nCrank arm\tCarbon fiber crank arm\nChainring\tAlloy chainring\nCassette\t12-speed cassette\nChain\t12-speed chain\n\nComponents\nSaddle\tCarbon fiber saddle\nSeatpost\tCarbon fiber seatpost\nHandlebar\tCarbon fiber handlebar\nGrips\tComfortable and secure grips\nStem\tCarbon fiber stem\nHeadset\tCarbon fiber headset\nBrake\tHydraulic disc brakes\nBrake rotor\tDisc brake rotor\n\nAccessories\nE-bike system\tPowerful motor with removable battery\nBattery\tLithium-ion battery\nCharger\tFast charging adapter\nController\tHandlebar-mounted controller\nTool\tBasic toolkit\n\nWeight\nWeight\tM - 17.5 kg / 38.5 lbs (with tubeless sealant)\n\nWeight limit\nThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing & fit\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| S | 160 - 170 cm 5'3\" - 5'7\" | 74 - 79 cm 29\" - 31\" |\n| M | 170 - 180 cm 5'7\" - 5'11\" | 79 - 84 cm 31\" - 33\" |\n| L | 180 - 190 cm 5'11\" - 6'3\" | 84 - 89 cm 33\" - 35\" |\n| XL | 190 - 200 cm 6'3\" - 6'7\" | 89 - 94 cm 35\" - 37\" |\n\n## Geometry\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Actual frame size | 50.0 | 53.3 | 55.6 | 58.8 |\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\n| A — Seat tube | 39.4 | 43.2 | 48.3 | 53.3 |\n| B — Seat tube angle | 72.3° | 72.6° | 72.8° | 72.8° |\n| C — Head tube length | 9.0 | 10.0 | 10.5 | 11.0 |\n| D — Head angle | 67.5° | 67.5° | 67.5° | 67.5° |\n| E — Effective top tube | 58.0 | 61.7 | 64.8 | 67.0 |\n| F — Bottom bracket height | 32.3 | 32.3 | 32.3 | 32.3 |\n| G — Bottom bracket drop | 5.0 | 5.0 | 5.0 | 5.0 |\n| H — Chainstay length | 44.7 | 44.7 | 44.7 | 44.7 |\n| I — Offset | 4.2 | 4.2 | 4.2 | 4.2 |\n| J — Trail | 10.9 | 10.9 | 10.9 | 10.9 |\n| K — Wheelbase | 112.6 | 116.5 | 119.7 | 121.9 |\n| L — Standover | 76.8 | 76.8 | 76.8 | 76.8 |\n| M — Frame reach | 40.5 | 44.0 | 47.0 | 49.0 |\n| N — Frame stack | 60.9 | 61.8 | 62.2 | 62.7 |", + "price": 1699.99, + "tags": [ + "bicycle", + "electric bike", + "city bike" + ] + }, + { + "name": "Axiom D8 E-Mountain Bike", + "shortDescription": "The Axiom D8 is an electrifying mountain bike that is built for adventure. It boasts a light aluminum frame, a powerful motor and the latest tech to tackle the toughest of terrains. The D8 provides assistance without adding bulk to the bike, giving you the flexibility to ride like a traditional mountain bike or have an extra push when you need it.", + "description": "## Overview \nIt's right for you if... \nYou're looking for an electric mountain bike that can handle a wide variety of terrain, from flowing singletrack to technical descents. You also want a bike that offers a powerful motor that provides assistance without adding bulk to the bike. The D8 is designed to take you anywhere, quickly and comfortably.\n\nThe tech you get \nA lightweight aluminum frame with 140mm of travel, a Suntour fork with hydraulic lockout, and a reliable and powerful Bafang M400 mid-motor that provides a boost up to 20 mph. The bike features a Shimano Deore drivetrain, hydraulic disc brakes, and a dropper seat post. With the latest tech on-board, the D8 is designed to take you to new heights.\n\nThe final word \nThe Axiom D8 is an outstanding electric mountain bike that is designed for adventure. It's built with the latest tech and provides the flexibility to ride like a traditional mountain bike or have an extra push when you need it. Whether you're a beginner or an experienced rider, the D8 is the perfect companion for your next adventure.\n\n## Features \nBuilt for Adventure \n\nThe D8 features a lightweight aluminum frame that is built to withstand rugged terrain. It comes equipped with 140mm of travel and a Suntour fork that can handle even the toughest of trails. With this bike, you're ready to take on anything the mountain can throw at you.\n\nPowerful Motor \n\nThe Bafang M400 mid-motor provides reliable and powerful assistance without adding bulk to the bike. You can quickly and easily switch between the different assistance levels to find the perfect balance between range and power.\n\nShimano Deore Drivetrain \n\nThe Shimano Deore drivetrain is reliable and offers smooth shifting on any terrain. You can easily adjust the gears to match your riding style and maximize your performance on the mountain.\n\nDropper Seat Post \n\nThe dropper seat post allows you to easily adjust your seat height on the fly, so you can maintain the perfect position for any terrain. With the flick of a switch, you can quickly and easily lower or raise your seat to match the terrain.\n\nHydraulic Disc Brakes \n\nThe D8 features powerful hydraulic disc brakes that offer reliable stopping power in any weather condition. You can ride with confidence knowing that you have the brakes to stop on a dime.\n\n## Specs \nFrameset \nFrame\tAluminum frame with 140mm of travel \nFork\tSuntour fork with hydraulic lockout, 140mm of travel \nShock\tN/A \nMax compatible fork travel\t140mm \n \nWheels \nWheel front\tAlloy wheel \nWheel rear\tAlloy wheel \nSkewer rear\tThru axle \nTire\t29\" x 2.35\" \nTire part\tN/A \nMax tire size\t29\" x 2.6\" \n \nDrivetrain \nShifter\tShimano Deore \nRear derailleur\tShimano Deore \nCrank\tBafang M400 \nCrank arm\tN/A \nChainring\tN/A \nCassette\tShimano Deore \nChain\tShimano Deore \nMax chainring size\tN/A \n \nComponents \nSaddle\tAxiom D8 saddle \nSeatpost\tDropper seat post \nHandlebar\tAxiom D8 handlebar \nGrips\tAxiom D8 grips \nStem\tAxiom D8 stem \nHeadset\tAxiom D8 headset \nBrake\tHydraulic disc brakes \nBrake rotor\t180mm \n\nAccessories \nE-bike system\tBafang M400 mid-motor \nBattery\tLithium-ion battery, 500Wh \nCharger\tLithium-ion charger \nController\tBafang M400 controller \nTool\tN/A \n \nWeight \nWeight\tM - 22 kg / 48.5 lbs \nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 136 kg (300 lbs). \n \n \n## Sizing & fit \n \n| Size | Rider Height | Inseam | \n|:----:|:------------------------:|:--------------------:| \n| S | 152 - 165 cm 5'0\" - 5'5\" | 70 - 76 cm 27\" - 30\" | \n| M | 165 - 178 cm 5'5\" - 5'10\" | 76 - 81 cm 30\" - 32\" | \n| L | 178 - 185 cm 5'10\" - 6'1\" | 81 - 86 cm 32\" - 34\" | \n| XL | 185 - 193 cm 6'1\" - 6'4\" | 86 - 91 cm 34\" - 36\" | \n \n \n## Geometry \n \nAll measurements provided in cm unless otherwise noted. \nSizing table \n| Frame size letter | S | M | L | XL | \n|---------------------------|-------|-------|-------|-------| \n| Actual frame size | 41.9 | 46.5 | 50.8 | 55.9 | \n| Wheel size | 29\" | 29\" | 29\" | 29\" | \n| A — Seat tube | 42.0 | 46.5 | 51.0 | 56.0 | \n| B — Seat tube angle | 74.0° | 74.0° | 74.0° | 74.0° | \n| C — Head tube length | 11.0 | 12.0 | 13.0 | 15.0 | \n| D — Head angle | 68.0° | 68.0° | 68.0° | 68.0° | \n| E — Effective top tube | 57.0 | 60.0 | 62.0 | 65.0 | \n| F — Bottom bracket height | 33.0 | 33.0 | 33.0 | 33.0 | \n| G — Bottom bracket drop | 3.0 | 3.0 | 3.0 | 3.0 | \n| H — Chainstay length | 46.0 | 46.0 | 46.0 | 46.0 | \n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 | \n| J — Trail | 10.9 | 10.9 | 10.9 | 10.9 | \n| K — Wheelbase | 113.0 | 116.0 | 117.5 | 120.5 | \n| L — Standover | 73.5 | 75.5 | 76.5 | 79.5 | \n| M — Frame reach | 41.0 | 43.5 | 45.0 | 47.5 | \n| N — Frame stack | 60.5 | 61.5 | 62.5 | 64.5 |", + "price": 1399.99, + "tags": [ + "bicycle", + "electric bike", + "mountain bike" + ] + }, + { + "name": "Velocity X1", + "shortDescription": "Velocity X1 is a high-performance road bike designed for speed enthusiasts. It features a lightweight yet durable frame, aerodynamic design, and top-quality components, making it the perfect choice for those who want to take their cycling experience to the next level.", + "description": "## Overview\nIt's right for you if...\nYou're an experienced cyclist looking for a bike that can keep up with your need for speed. You want a bike that's lightweight, aerodynamic, and built to perform, whether you're training for a race or just pushing yourself to go faster.\n\nThe tech you get\nA lightweight aluminum frame with a carbon fork, Shimano Ultegra groupset with a wide range of gearing, hydraulic disc brakes, aerodynamic carbon wheels, and a vibration-absorbing handlebar with ergonomic grips.\n\nThe final word\nVelocity X1 is the ultimate road bike for speed enthusiasts. Its lightweight frame, aerodynamic design, and top-quality components make it the perfect choice for those who want to take their cycling experience to the next level.\n\n\n## Features\n\nAerodynamic design\nVelocity X1 is built with an aerodynamic design to help you go faster with less effort. It features a sleek profile, hidden cables, and a carbon fork that cuts through the wind, reducing drag and increasing speed.\n\nHydraulic disc brakes\nVelocity X1 comes equipped with hydraulic disc brakes, providing excellent stopping power in all weather conditions. They're also low maintenance, with minimal adjustments needed over time.\n\nCarbon wheels\nThe Velocity X1's aerodynamic carbon wheels provide excellent speed and responsiveness, helping you achieve your fastest times yet. They're also lightweight, reducing overall bike weight and making acceleration and handling even easier.\n\nShimano Ultegra groupset\nThe Shimano Ultegra groupset provides smooth shifting and reliable performance, ensuring you get the most out of every ride. With a wide range of gearing options, it's ideal for tackling any terrain, from steep climbs to fast descents.\n\n\n## Specifications\nFrameset\nFrame with Fork\tAluminium frame, internal cable routing, 135x9mm QR\nFork\tCarbon, hidden cable routing, 100x9mm QR\n\nWheels\nWheel front\tCarbon, 30mm deep rim, 23mm width, 100x9mm QR\nWheel rear\tCarbon, 30mm deep rim, 23mm width, 135x9mm QR\nSkewer front\t100x9mm QR\nSkewer rear\t135x9mm QR\nTire\tContinental Grand Prix 5000, 700x25mm, folding bead\nMax tire size\t700x28mm without fenders\n\nDrivetrain\nShifter\tShimano Ultegra R8020, 11 speed\nRear derailleur\tShimano Ultegra R8000, 11 speed\n*Crank\tSize: S, M\nShimano Ultegra R8000, 50/34T, 170mm length\nSize: L, XL\nShimano Ultegra R8000, 50/34T, 175mm length\nBottom bracket\tShimano BB-RS500-PB, PressFit\nCassette\tShimano Ultegra R8000, 11-30T, 11 speed\nChain\tShimano Ultegra HG701, 11 speed\nPedal\tNot included\nMax chainring size\t50/34T\n\nComponents\nSaddle\tBontrager Montrose Comp, steel rails, 138mm width\nSeatpost\tBontrager Comp, 6061 alloy, 27.2mm, 8mm offset, 330mm length\n*Handlebar\tSize: S, M, L\nBontrager Elite Aero VR-CF, alloy, 31.8mm, 93mm reach, 123mm drop, 400mm width\nSize: XL\nBontrager Elite Aero VR-CF, alloy, 31.8mm, 93mm reach, 123mm drop, 420mm width\nGrips\tBontrager Supertack Perf tape\n*Stem\tSize: S, M, L\nBontrager Elite Blendr, 31.8mm clamp, 7 degree, 90mm length\nSize: XL\nBontrager Elite Blendr, 31.8mm clamp, 7 degree, 100mm length\nBrake\tShimano Ultegra R8070 hydraulic disc, flat mount\nBrake rotor\tShimano RT800, centerlock, 160mm\nRotor size\tMax brake rotor sizes: 160mm front & rear\n\nWeight\nWeight\tM - 8.15 kg / 17.97 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| S | 162 - 170 cm 5'4\" - 5'7\" | 74 - 78 cm 29\" - 31\" |\n| M | 170 - 178 cm 5'7\" - 5'10\" | 77 - 82 cm 30\" - 32\" |\n| L | 178 - 186 cm 5'10\" - 6'1\" | 82 - 86 cm 32\" - 34\" |\n| XL | 186 - 196 cm 6'1\" - 6'5\" | 87 - 92 cm 34\" - 36\" |\n\n\n## Geometry\n| Frame size letter | S | M | L | XL |\n|---------------------------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 50.0 | 52.0 | 54.0 | 56.0 |\n| B — Seat tube angle | 74.0° | 73.5° | 73.0° | 72.5° |\n| C — Head tube length | 13.0 | 15.0 | 17.0 | 19.0 |\n| D — Head angle | 71.0° | 72.0° | 72.0° | 72.5° |\n| E — Effective top tube | 53.7 | 55.0 | 56.5 | 58.0 |\n| F — Bottom bracket height | 27.5 | 27.5 | 27.5 | 27.5 |\n| G — Bottom bracket drop | 7.3 | 7.3 | 7.3 | 7.3 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 6.0 | 6.0 | 6.0 | 5.8 |\n| K — Wheelbase | 98.2 | 99.1 | 100.1 | 101.0 |\n| L — Standover | 75.2 | 78.2 | 81.1 | 84.1 |\n| M — Frame reach | 37.5 | 38.3 | 39.1 | 39.9 |\n| N — Frame stack | 53.3 | 55.4 | 57.4 | 59.5 |", + "price": 1799.99, + "tags": [ + "bicycle", + "touring bike" + ] + }, + { + "name": "Velocity V9", + "shortDescription": "Velocity V9 is a high-performance hybrid bike that combines speed and comfort for riders who demand the best of both worlds. The lightweight aluminum frame, along with the carbon fork and seat post, provide optimal stiffness and absorption to tackle any terrain. A 2x Shimano Deore drivetrain, hydraulic disc brakes, and 700c wheels with high-quality tires make it a versatile ride for commuters, fitness riders, and weekend adventurers alike.", + "description": "## Overview\nIt's right for you if...\nYou want a fast, versatile bike that can handle anything from commuting to weekend adventures. You value comfort as much as speed and performance. You want a reliable and durable bike that will last for years to come.\n\nThe tech you get\nA lightweight aluminum frame with a carbon fork and seat post, a 2x Shimano Deore drivetrain with a wide range of gearing, hydraulic disc brakes, and 700c wheels with high-quality tires. The Velocity V9 is designed for riders who demand both performance and comfort in one package.\n\nThe final word\nThe Velocity V9 is the perfect bike for riders who want speed and performance without sacrificing comfort. The lightweight aluminum frame and carbon components provide optimal stiffness and absorption, while the 2x Shimano Deore drivetrain and hydraulic disc brakes ensure precise shifting and stopping power. Whether you're commuting, hitting the trails, or training for your next race, the Velocity V9 has everything you need to achieve your goals.\n\n## Features\n\n2x drivetrain\nA 2x drivetrain means more versatility and a wider range of gearing options. Whether you're climbing hills or sprinting on the flats, the Velocity V9 has the perfect gear for any situation.\n\nCarbon components\nThe Velocity V9 features a carbon fork and seat post to provide optimal stiffness and absorption. This means you can ride faster and more comfortably over any terrain.\n\nHydraulic disc brakes\nHydraulic disc brakes provide unparalleled stopping power and modulation in any weather condition. You'll feel confident and in control no matter where you ride.\n\n## Specifications\nFrameset\nFrame with Fork\tAluminum frame with carbon fork and seat post, internal cable routing, fender mounts, 135x5mm ThruSkew\nFork\tCarbon fork, hidden fender mounts, flat mount disc, 5x100mm thru-skew\n\nWheels\nWheel front\tDouble wall aluminum rims, 700c, quick release hub\nWheel rear\tDouble wall aluminum rims, 700c, quick release hub\nTire\tKenda Kwick Tendril, puncture resistant, reflective sidewall, 700x32c\nMax tire size\t700x35c without fenders, 700x32c with fenders\n\nDrivetrain\nShifter\tShimano Deore, 10 speed\nFront derailleur\tShimano Deore\nRear derailleur\tShimano Deore\nCrank\tShimano Deore, 46-30T, 170mm (S/M), 175mm (L/XL)\nBottom bracket\tShimano BB52, 68mm, threaded\nCassette\tShimano Deore, 11-36T, 10 speed\nChain\tShimano HG54, 10 speed\nPedal\tWellgo alloy platform\n\nComponents\nSaddle\tVelo VL-2158, steel rails\nSeatpost\tCarbon seat post, 27.2mm\nHandlebar\tAluminum, 31.8mm clamp, 15mm rise, 680mm width\nGrips\tVelo ergonomic grips\nStem\tAluminum, 31.8mm clamp, 7 degree, 90mm length\nBrake\tShimano hydraulic disc, MT200 lever, MT200 caliper\nBrake rotor\tShimano RT56, centerlock, 160mm\nRotor size\tMax brake rotor sizes: 160mm front & rear\n\nWeight\nWeight\tM - 11.5 kg / 25.35 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size | S | M | L | XL |\n|--------------------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 44.0 | 48.0 | 52.0 | 56.0 |\n| B — Seat tube angle | 74.5° | 74.0° | 73.5° | 73.0° |\n| C — Head tube length | 14.5 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 71.0° | 71.0° | 71.5° | 71.5° |\n| E — Effective top tube | 56.5 | 57.5 | 58.5 | 59.5 |\n| F — Bottom bracket height | 27.0 | 27.0 | 27.0 | 27.0 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 43.0 | 43.0 | 43.0 | 43.0 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 7.0 | 7.0 | 6.6 | 6.6 |\n| K — Wheelbase | 105.4 | 106.3 | 107.2 | 108.2 |\n| L — Standover | 73.2 | 77.1 | 81.2 | 85.1 |\n| M — Frame reach | 39.0 | 39.8 | 40.4 | 41.3 |\n| N — Frame stack | 57.0 | 58.5 | 60.0 | 61.5 |", + "price": 2199.99, + "tags": [ + "bicycle", + "electric bike", + "mountain bike" + ] + }, + { + "name": "Aero Pro X", + "shortDescription": "Aero Pro X is a high-end racing bike designed for serious cyclists who demand speed, agility, and superior performance. The lightweight carbon frame and fork, combined with the aerodynamic design, provide optimal stiffness and efficiency to maximize your speed. The bike features a 2x Shimano Ultegra drivetrain, hydraulic disc brakes, and 700c wheels with high-quality tires. Whether you're competing in a triathlon or climbing steep hills, Aero Pro X delivers exceptional performance and precision handling.", + "description": "## Overview\nIt's right for you if...\nYou are a competitive cyclist looking for a bike that is designed for racing. You want a bike that delivers exceptional speed, agility, and precision handling. You demand superior performance and reliability from your equipment.\n\nThe tech you get\nA lightweight carbon frame with an aerodynamic design, a carbon fork with hidden fender mounts, a 2x Shimano Ultegra drivetrain with a wide range of gearing, hydraulic disc brakes, and 700c wheels with high-quality tires. Aero Pro X is designed for serious cyclists who demand nothing but the best.\n\nThe final word\nAero Pro X is the ultimate racing bike for serious cyclists. The lightweight carbon frame and aerodynamic design deliver maximum speed and efficiency, while the 2x Shimano Ultegra drivetrain and hydraulic disc brakes ensure precise shifting and stopping power. Whether you're competing in a triathlon or a criterium race, Aero Pro X delivers the performance you need to win.\n\n## Features\n\nAerodynamic design\nThe Aero Pro X features an aerodynamic design that reduces drag and maximizes efficiency. The bike is optimized for speed and agility, so you can ride faster and farther with less effort.\n\nHydraulic disc brakes\nHydraulic disc brakes provide unrivaled stopping power and modulation in any weather condition. You'll feel confident and in control no matter where you ride.\n\nCarbon components\nThe Aero Pro X features a carbon fork with hidden fender mounts to provide optimal stiffness and absorption. This means you can ride faster and more comfortably over any terrain.\n\n## Specifications\nFrameset\nFrame with Fork\tCarbon frame with an aerodynamic design, internal cable routing, 3s chain keeper, 142x12mm thru-axle\nFork\tCarbon fork with hidden fender mounts, flat mount disc, 100x12mm thru-axle\n\nWheels\nWheel front\tDouble wall carbon rims, 700c, thru-axle hub\nWheel rear\tDouble wall carbon rims, 700c, thru-axle hub\nTire\tContinental Grand Prix 5000, folding bead, 700x25c\nMax tire size\t700x28c without fenders, 700x25c with fenders\n\nDrivetrain\nShifter\tShimano Ultegra, 11 speed\nFront derailleur\tShimano Ultegra\nRear derailleur\tShimano Ultegra\nCrank\tShimano Ultegra, 52-36T, 170mm (S), 172.5mm (M), 175mm (L/XL)\nBottom bracket\tShimano BB72, 68mm, PressFit\nCassette\tShimano Ultegra, 11-30T, 11 speed\nChain\tShimano HG701, 11 speed\nPedal\tNot included\n\nComponents\nSaddle\tBontrager Montrose Elite, carbon rails, 138mm width\nSeatpost\tCarbon seat post, 27.2mm, 20mm offset\nHandlebar\tBontrager XXX Aero, carbon, 31.8mm clamp, 75mm reach, 125mm drop\nGrips\tBontrager Supertack Perf tape\nStem\tBontrager Pro, 31.8mm clamp, 7 degree, 90mm length\nBrake\tShimano hydraulic disc, Ultegra lever, Ultegra caliper\nBrake rotor\tShimano RT800, centerlock, 160mm\nRotor size\tMax brake rotor sizes: 160mm front & rear\n\nWeight\nWeight\tM - 8.36 kg / 18.42 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\n| Size | Rider Height |\n|:----:|:-------------------------:|\n| S | 155 - 165 cm 5'1\" - 5'5\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" |\n\n## Geometry\n| Frame size | S | M | L | XL |\n|--------------------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 50.6 | 52.4 | 54.3 | 56.2 |\n| B — Seat tube angle | 75.5° | 74.5° | 73.5° | 72.5° |\n| C — Head tube length | 12.0 | 14.0 | 16.0 | 18.0 |\n| D — Head angle | 72.5° | 73.0° | 73.5° | 74.0° |\n| E — Effective top tube | 53.8 | 55.4 | 57.0 | 58.6 |\n| F — Bottom bracket height | 26.5 | 26.5 | 26.5 | 26.5 |\n| G — Bottom bracket drop | 7.0 | 7.0 | 7.0 | 7.0 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 6.0 | 6.0 | 6.0 | 6.0 |\n| K — Wheelbase | 97.1 | 98.7 | 100.2 | 101.8 |\n| L — Standover | 73.8 | 76.2 | 78.5 | 80.8 |\n| M — Frame reach | 38.8 | 39.5 | 40.2 | 40.9 |\n| N — Frame stack | 52.8 | 54.7 | 56.6 | 58.5 |", + "price": 1599.99, + "tags": [ + "bicycle", + "road bike" + ] + }, + { + "name": "Voltex+ Ultra Lowstep", + "shortDescription": "Voltex+ Ultra Lowstep is a high-performance electric hybrid bike designed for riders who seek speed, comfort, and reliability during their everyday rides. Equipped with a powerful and efficient Voltex Drive Pro motor and a fully-integrated 600Wh battery, this e-bike allows you to cover longer distances on a single charge. The Voltex+ Ultra Lowstep comes with premium components that prioritize comfort and safety, such as a suspension seatpost, wide and stable tires, and integrated lights.", + "description": "## Overview\n\nIt's right for you if...\nYou want an e-bike that provides a boost for faster rides and effortless usage. Durability is crucial, and you need a bike with one of the most powerful and efficient motors.\n\nThe tech you get\nA lightweight Delta Carbon Fiber frame with an ultra-lowstep design, a Voltex Drive Pro (350W, 75Nm) motor capable of maintaining speeds up to 30 mph, an extended range 600Wh battery integrated into the frame, and a Voltex Control Panel. Additionally, it features a 12-speed Shimano drivetrain, hydraulic disc brakes for optimal all-weather stopping power, a suspension seatpost, wide puncture-resistant tires for added stability, ergonomic grips, a kickstand, lights, and a cargo rack.\n\nThe final word\nThis bike offers enhanced enjoyment and ease of use on long commutes, leisure rides, and adventures. With its extended-range battery, powerful Voltex motor, user-friendly controller, and a seatpost that smooths out road vibrations, it guarantees an exceptional riding experience.\n\n## Features\n\nUltra-fast assistance\n\nExperience speeds up to 30 mph with the cutting-edge Voltex Drive Pro motor, allowing you to breeze through errands, commutes, and joyrides.\n\n## Specs\n\nFrameset\n- Frame: Delta Carbon Fiber, Removable Integrated Battery (RIB), sleek welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\n- Fork: Voltex Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Hub front: Formula DC-20, alloy, 6-bolt, 5x100mm QR\n- Skewer front: 132x5mm QR, ThruSkew\n- Hub rear: Formula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Skewer rear: 153x5mm bolt-on\n- Rim: Voltex Connection, double-wall, 32-hole, 20 mm width, Schrader valve\n- Tire: Voltex E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore XT M8100, 12-speed\n- Rear derailleur: Shimano Deore XT M8100, long cage\n- Crank: Voltex alloy, 170mm length\n- Chainring: FSA, 44T, aluminum with guard\n- Cassette: Shimano Deore XT M8100, 10-51, 12-speed\n- Chain: KMC E12 Turbo\n- Pedal: Voltex Urban pedals\n\nComponents\n- Saddle: Voltex Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar: Voltex alloy, 31.8mm, comfort sweep, 620mm width (XS, S, M), 660mm width (L)\n- Grips: Voltex Satellite Elite, alloy lock-on\n- Stem: Voltex alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length (XS, S), 105mm length (M, L)\n- Headset: VP sealed cartridge, 1-1/8'', threaded\n- Brake: Shimano MT520 hydraulic disc\n- Brake rotor: Shimano RT56, 6-bolt, 180mm (XS, S, M, L), 160mm (XS, S, M, L)\n\nAccessories\n- Battery: Voltex PowerTube 600Wh\n- Charger: Voltex compact 2A, 100-240V\n- Computer: Voltex Control Panel\n- Motor: Voltex Drive Pro, 75Nm, 30mph\n- Light: Voltex Solo for e-bike, taillight (XS, S, M, L), Voltex MR8, 180 lumen, 60 lux, LED, headlight (XS, S, M, L)\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: Voltex-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender: Voltex wide (XS, S, M, L), Voltex plastic (XS, S, M, L)\n\nWeight\n- Weight: M - 20.50 kg / 45.19 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 330 pounds (150 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\nSizing table\n\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 38.0 | 43.0 | 48.0 | 53.0 |\n| B — Seat tube angle | 70.5° | 70.5° | 70.5° | 70.5° |\n| C — Head tube length | 15.0 | 15.0 | 17.0 | 19.0 |\n| D — Head angle | 69.2° | 69.2° | 69.2° | 69.2° |\n| E — Effective top tube | 57.2 | 57.7 | 58.8 | 60.0 |\n| F — Bottom bracket height | 30.3 | 30.3 | 30.3 | 30.3 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.5 | 48.5 | 48.5 | 48.5 |\n| I — Offset | 5.0 | 5.0 | 5.0 | 5.0 |\n| J — Trail | 9.0 | 9.0 | 9.0 | 9.0 |\n| K — Wheelbase | 111.8 | 112.3 | 113.6 | 114.8 |\n| L — Standover | 42.3 | 42.3 | 42.3 | 42.3 |\n| M — Frame reach | 36.0 | 38.0 | 38.0 | 38.0 |\n| N — Frame stack | 62.0 | 62.0 | 63.9 | 65.8 |\n| Stem length | 8.0 | 8.5 | 8.5 | 10.5 |\n\nPlease note that the specifications and features listed above are subject to change and may vary based on different models and versions of the Voltex+ Ultra Lowstep bike.", + "price": 2999.99, + "tags": [ + "bicycle", + "road bike", + "professional" + ] + }, + { + "name": "SwiftRide Hybrid", + "shortDescription": "SwiftRide Hybrid is a versatile and efficient bike designed for riders who want a smooth and enjoyable ride on various terrains. It incorporates advanced technology and high-quality components to provide a comfortable and reliable cycling experience.", + "description": "## Overview\n\nIt's right for you if...\nYou are looking for a bike that combines the benefits of an electric bike with the versatility of a hybrid. You value durability, speed, and ease of use.\n\nThe tech you get\nThe SwiftRide Hybrid features a lightweight and durable aluminum frame, making it easy to handle and maneuver. It is equipped with a powerful electric motor that offers a speedy assist, helping you reach speeds of up to 25 mph. The bike comes with a removable and fully-integrated 500Wh battery, providing a long-range capacity for extended rides. It also includes a 10-speed Shimano drivetrain, hydraulic disc brakes for precise stopping power, wide puncture-resistant tires for stability, and integrated lights for enhanced visibility.\n\nThe final word\nThe SwiftRide Hybrid is designed for riders who want a bike that can handle daily commutes, recreational rides, and adventures. With its efficient motor, intuitive controls, and comfortable features, it offers an enjoyable and hassle-free riding experience.\n\n## Features\n\nEfficient electric assist\nExperience the thrill of effortless riding with the powerful electric motor that provides a speedy assist, making your everyday rides faster and more enjoyable.\n\n## Specs\n\nFrameset\n- Frame: Lightweight Aluminum, Removable Integrated Battery (RIB), rack & fender mounts, internal routing, 135x5mm QR\n- Fork: SwiftRide Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Hub front: Formula DC-20, alloy, 6-bolt, 5x100mm QR\n- Skewer front: 132x5mm QR, ThruSkew\n- Hub rear: Formula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Skewer rear: 153x5mm bolt-on\n- Rim: SwiftRide Connection, double-wall, 32-hole, 20 mm width, Schrader valve\n- Tire: SwiftRide E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear derailleur: Shimano Deore M5120, long cage\n- Crank: ProWheel alloy, 170mm length\n- Chainring: FSA, 42T, steel w/guard\n- Cassette: Shimano Deore M4100, 11-42, 10 speed\n- Chain: KMC E10\n- Pedal: SwiftRide City pedals\n\nComponents\n- Saddle: SwiftRide Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar:\n - Size: XS, S, M - SwiftRide alloy, 31.8mm, comfort sweep, 620mm width\n - Size: L - SwiftRide alloy, 31.8mm, comfort sweep, 660mm width\n- Grips: SwiftRide Satellite Elite, alloy lock-on\n- Stem:\n - Size: XS, S - SwiftRide alloy quill, 31.8mm clamp, adjustable rise, 85mm length\n - Size: M, L - SwiftRide alloy quill, 31.8mm clamp, adjustable rise, 105mm length\n- Headset: VP sealed cartridge, 1-1/8'', threaded\n- Brake: Shimano MT200 hydraulic disc\n- Brake rotor:\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 180mm\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 160mm\n\nAccessories\n- Battery: SwiftRide PowerTube 500Wh\n- Charger: SwiftRide compact 2A, 100-240V\n- Computer: SwiftRide Purion\n- Motor: SwiftRide Performance Line Sport, 65Nm, 25mph\n- Light:\n - Size: XS, S, M, L - SwiftRide SOLO for e-bike, taillight\n - Size: XS, S, M, L - SwiftRide MR8, 180 lumen, 60 lux, LED, headlight\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: SwiftRide-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender:\n - Size: XS, S, M, L - SwiftRide wide\n - Size: XS, S, M, L - SwiftRide plastic\n\nWeight\n- Weight: M - 22.30 kg / 49.17 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm (4'10\" - 5'1\") | 69 - 73 cm (27\" - 29\") |\n| S | 155 - 165 cm (5'1\" - 5'5\") | 72 - 78 cm (28\" - 31\") |\n| M | 165 - 175 cm (5'5\" - 5'9\") | 77 - 83 cm (30\" - 33\") |\n| L | 175 - 186 cm (5'9\" - 6'1\") | 82 - 88 cm (32\" - 35\") |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\nSizing table\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", + "price": 3999.99, + "tags": [ + "bicycle", + "mountain bike", + "professional" + ] + }, + { + "name": "RoadRunner E-Speed Lowstep", + "shortDescription": "RoadRunner E-Speed Lowstep is a high-performance electric hybrid designed for riders seeking speed and excitement on their daily rides. It is equipped with a powerful and reliable ThunderBolt drive unit that offers exceptional acceleration. The bike features a fully-integrated 500Wh battery, allowing riders to cover longer distances on a single charge. With its comfortable and safe components, including a suspension seatpost, wide and stable tires, and integrated lights, the RoadRunner E-Speed Lowstep ensures a smooth and enjoyable ride.", + "description": "## Overview\n\nIt's right for you if...\nYou're looking for an e-bike that provides an extra boost to reach your destination quickly and effortlessly. You prioritize durability and want a bike with one of the fastest motors available.\n\nThe tech you get\nA lightweight and sturdy ThunderBolt aluminum frame with a lowstep geometry. The bike is equipped with a ThunderBolt Performance Sport (250W, 65Nm) drive unit capable of reaching speeds up to 28 mph. It features a long-range 500Wh battery fully integrated into the frame and a ThunderBolt controller. Additionally, the bike has a 10-speed Shimano drivetrain, hydraulic disc brakes for reliable stopping power in all weather conditions, a suspension seatpost, wide puncture-resistant tires for stability, ergonomic grips, a kickstand, lights, and a rack and fenders.\n\nThe final word\nThe RoadRunner E-Speed Lowstep is designed to provide enjoyment and ease of use on longer commutes, recreational rides, and adventurous journeys. Its long-range battery, fast ThunderBolt motor, intuitive controller, and road-smoothing suspension seatpost make it the perfect choice for riders seeking both comfort and speed.\n\n## Features\n\nSuper speedy assist\n\nThe ThunderBolt Performance Sport drive unit allows you to accelerate up to 28mph, making errands, commutes, and joyrides a breeze.\n\n## Specs\n\nFrameset\n- Frame: ThunderBolt Smooth Aluminum, Removable Integrated Battery (RIB), sleek welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\n- Fork: RoadRunner Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Hub front: ThunderBolt DC-20, alloy, 6-bolt, 5x100mm QR\n- Skewer front: 132x5mm QR, ThruSkew\n- Hub rear: ThunderBolt DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Skewer rear: 153x5mm bolt-on\n- Rim: ThunderBolt Connection, double-wall, 32-hole, 20 mm width, Schrader valve\n- Tire: ThunderBolt E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear derailleur: Shimano Deore M5120, long cage\n- Crank: ProWheel alloy, 170mm length\n- Chainring: FSA, 42T, steel w/guard\n- Cassette: Shimano Deore M4100, 11-42, 10 speed\n- Chain: KMC E10\n- Pedal: RoadRunner City pedals\n\nComponents\n- Saddle: RoadRunner Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar:\n - Size: XS, S, M - RoadRunner alloy, 31.8mm, comfort sweep, 620mm width\n - Size: L - RoadRunner alloy, 31.8mm, comfort sweep, 660mm width\n- Grips: RoadRunner Satellite Elite, alloy lock-on\n- Stem:\n - Size: XS, S - RoadRunner alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length\n - Size: M, L - RoadRunner alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 105mm length\n- Headset: VP sealed cartridge, 1-1/8'', threaded\n- Brake: Shimano MT200 hydraulic disc\n- Brake rotor:\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 180mm\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 160mm\n\nAccessories\n- Battery: ThunderBolt PowerTube 500Wh\n- Charger: ThunderBolt compact 2A, 100-240V\n- Computer: ThunderBolt Purion\n- Motor: ThunderBolt Performance Line Sport, 65Nm, 28mph\n- Light:\n - Size: XS, S, M, L - ThunderBolt SOLO for e-bike, taillight\n - Size: XS, S, M, L - ThunderBolt MR8, 180 lumen, 60 lux, LED, headlight\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: MIK-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender:\n - Size: XS, S, M, L - RoadRunner wide\n - Size: XS, S, M, L - RoadRunner plastic\n\nWeight\n- Weight: M - 22.30 kg / 49.17 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\nSizing table\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", + "price": 4999.99, + "tags": [ + "bicycle", + "road bike", + "professional" + ] + }, + { + "name": "Hyperdrive Turbo X1", + "shortDescription": "Hyperdrive Turbo X1 is a high-performance electric bike designed for riders seeking an exhilarating experience on their daily rides. It features a powerful and efficient Hyperdrive Sport drive unit and a sleek, integrated 500Wh battery for extended range. This e-bike is equipped with top-of-the-line components prioritizing comfort and safety, including a suspension seatpost, wide and stable tires, and integrated lights.", + "description": "## Overview\n\nIt's right for you if...\nYou crave the thrill of an e-bike that can accelerate rapidly, reaching high speeds effortlessly. You value durability and are looking for a bike that is equipped with one of the fastest motors available.\n\nThe tech you get\nA lightweight Hyper Alloy frame with a lowstep geometry, a Hyperdrive Sport (300W, 70Nm) drive unit capable of maintaining speeds up to 30 mph, a long-range 500Wh battery seamlessly integrated into the frame, and an intuitive Hyper Control controller. Additionally, it features a 10-speed Shimano drivetrain, hydraulic disc brakes for reliable stopping power in all weather conditions, a suspension seatpost, wide puncture-resistant tires for enhanced stability, ergonomic grips, a kickstand, lights, and a rack and fenders.\n\nThe final word\nThis bike is designed for riders seeking enjoyment and convenience on longer commutes, recreational rides, and thrilling adventures. With its long-range battery, high-speed motor, user-friendly controller, and smooth-riding suspension seatpost, the Hyperdrive Turbo X1 guarantees an exceptional e-biking experience.\n\n## Features\n\nHyperboost Acceleration\nExperience adrenaline-inducing rides with the powerful Hyperdrive Sport drive unit that enables quick acceleration and effortless cruising through errands, commutes, and joyrides.\n\n## Specs\n\nFrameset\nFrame\tHyper Alloy, Removable Integrated Battery (RIB), seamless welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\nFork\tHyper Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\nMax compatible fork travel\t50mm\n\nWheels\nHub front\tFormula DC-20, alloy, 6-bolt, 5x100mm QR\nSkewer front\t132x5mm QR, ThruSkew\nHub rear\tFormula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\nSkewer rear\t153x5mm bolt-on\nRim\tHyper Connection, double-wall, 32-hole, 20 mm width, Schrader valve\nTire\tHyper E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\nMax tire size\t700x50mm with or without fenders\n\nDrivetrain\nShifter\tShimano Deore M4100, 10 speed\nRear derailleur\tShimano Deore M5120, long cage\nCrank\tProWheel alloy, 170mm length\nChainring\tFSA, 42T, steel w/guard\nCassette\tShimano Deore M4100, 11-42, 10 speed\nChain\tKMC E10\nPedal\tHyper City pedals\n\nComponents\nSaddle\tHyper Boulevard\nSeatpost\tAlloy, suspension, 31.6mm, 300mm length\n*Handlebar\tSize: XS, S, M\nHyper alloy, 31.8mm, comfort sweep, 620mm width\nSize: L\nHyper alloy, 31.8mm, comfort sweep, 660mm width\nGrips\tHyper Satellite Elite, alloy lock-on\n*Stem\tSize: XS, S\nHyper alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length\nSize: M, L\nHyper alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 105mm length\nHeadset\tVP sealed cartridge, 1-1/8'', threaded\nBrake\tShimano MT200 hydraulic disc\n*Brake rotor\tSize: XS, S, M, L\nShimano RT26, 6-bolt,180mm\nSize: XS, S, M, L\nShimano RT26, 6-bolt,160mm\n\nAccessories\nBattery\tHyper PowerTube 500Wh\nCharger\tHyper compact 2A, 100-240V\nComputer\tHyper Control\nMotor\tHyperdrive Sport, 70Nm, 30mph\n*Light\tSize: XS, S, M, L\nSpanninga SOLO for e-bike, taillight\nSize: XS, S, M, L\nHerrmans MR8, 180 lumen, 60 lux, LED, headlight\nKickstand\tAdjustable length rear mount alloy kickstand\nCargo rack\tMIK-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n*Fender\tSize: XS, S, M, L\nSKS wide\nSize: XS, S, M, L\nSKS plastic\n\nWeight\nWeight\tM - 22.30 kg / 49.17 lbs\nWeight limit\tThis bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\n\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", + "price": 1999.99, + "tags": [ + "bicycle", + "city bike", + "professional" + ] + }, + { + "name": "Horizon+ Evo Lowstep", + "shortDescription": "The Horizon+ Evo Lowstep is a versatile electric hybrid bike designed for riders seeking a thrilling and efficient riding experience on a variety of terrains. With its powerful Bosch Performance Line Sport drive unit and integrated 500Wh battery, this e-bike enables riders to cover long distances with ease. Equipped with features prioritizing comfort and safety, such as a suspension seatpost, stable tires, and integrated lights, the Horizon+ Evo Lowstep is a reliable companion for everyday rides.", + "description": "## Overview\n\nIt's right for you if...\nYou desire the convenience and speed of an e-bike to enhance your riding, and you want an intuitive and durable bicycle. You prioritize having one of the fastest motors developed by Bosch.\n\nThe tech you get\nA lightweight Alpha Smooth Aluminum frame with a lowstep geometry, a Bosch Performance Line Sport (250W, 65Nm) drive unit capable of sustaining speeds up to 28 mph, a fully encased 500Wh battery integrated into the frame, and a Bosch Purion controller. Additionally, it features a 10-speed Shimano drivetrain, hydraulic disc brakes for reliable stopping power in all weather conditions, a suspension seatpost, wide puncture-resistant tires for improved stability, ergonomic grips, a kickstand, lights, and a rack and fenders.\n\nThe final word\nThe Horizon+ Evo Lowstep offers an enjoyable and user-friendly riding experience for longer commutes, recreational rides, and adventures. It boasts an extended range battery, a high-performance Bosch motor, an intuitive controller, and a suspension seatpost for a smooth ride on various road surfaces.\n\n## Features\n\nSuper speedy assist\nExperience effortless cruising through errands, commutes, and joyrides with the new Bosch Performance Sport drive unit, allowing acceleration of up to 28 mph.\n\n## Specs\n\nFrameset\n- Frame: Alpha Platinum Aluminum, Removable Integrated Battery (RIB), smooth welds, rack & fender mounts, internal routing, kickstand mount, 135x5mm QR\n- Fork: Horizon Alloy, threaded steel steerer, rack mounts, post mount disc, 460mm axle-to-crown, ThruSkew 5mm QR\n- Max compatible fork travel: 50mm\n\nWheels\n- Front Hub: Formula DC-20, alloy, 6-bolt, 5x100mm QR\n- Front Skewer: 132x5mm QR, ThruSkew\n- Rear Hub: Formula DC-22, alloy, 6-bolt, Shimano 8/9/10 freehub, 135x5mm QR\n- Rear Skewer: 153x5mm bolt-on\n- Rim: Bontrager Connection, double-wall, 32-hole, 20mm width, Schrader valve\n- Tire: Bontrager E6 Hard-Case Lite, reflective, wire bead, 60tpi, 700x50c\n- Max tire size: 700x50mm with or without fenders\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10-speed\n- Rear Derailleur: Shimano Deore M5120, long cage\n- Crank: ProWheel alloy, 170mm length\n- Chainring: FSA, 42T, steel w/guard\n- Cassette: Shimano Deore M4100, 11-42, 10-speed\n- Chain: KMC E10\n- Pedal: Bontrager City pedals\n\nComponents\n- Saddle: Bontrager Boulevard\n- Seatpost: Alloy, suspension, 31.6mm, 300mm length\n- Handlebar:\n - Size: XS, S, M - Bontrager alloy, 31.8mm, comfort sweep, 620mm width\n - Size: L - Bontrager alloy, 31.8mm, comfort sweep, 660mm width\n- Grips: Bontrager Satellite Elite, alloy lock-on\n- Stem:\n - Size: XS, S - Bontrager alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 85mm length\n - Size: M, L - Bontrager alloy quill, 31.8mm clamp, adjustable rise, Blendr compatible, 105mm length\n- Headset: VP sealed cartridge, 1-1/8\", threaded\n- Brake: Shimano MT200 hydraulic disc\n- Brake rotor:\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 180mm\n - Size: XS, S, M, L - Shimano RT26, 6-bolt, 160mm\n\nAccessories\n- Battery: Bosch PowerTube 500Wh\n- Charger: Bosch compact 2A, 100-240V\n- Computer: Bosch Purion\n- Motor: Bosch Performance Line Sport, 65Nm, 28mph\n- Light:\n - Size: XS, S, M, L - Spanninga SOLO for e-bike, taillight\n - Size: XS, S, M, L - Herrmans MR8, 180 lumen, 60 lux, LED, headlight\n- Kickstand: Adjustable length rear mount alloy kickstand\n- Cargo rack: MIK-compatible alloy rear rack, maximum load 25 kg / 55 lbs\n- Fender:\n - Size: XS, S, M, L - SKS wide\n - Size: XS, S, M, L - SKS plastic\n\nWeight\n- Weight: M - 22.30 kg / 49.17 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| XS | 147 - 155 cm 4'10\" - 5'1\" | 69 - 73 cm 27\" - 29\" |\n| S | 155 - 165 cm 5'1\" - 5'5\" | 72 - 78 cm 28\" - 31\" |\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n\n## Geometry\n\nAll measurements provided in cm unless otherwise noted.\nSizing table\n| Frame size number | 40 cm | 45 cm | 50 cm | 55 cm |\n|---------------------------|-------|-------|-------|-------|\n| Frame size letter | XS | S | M | L |\n| Wheel size | 700c | 700c | 700c | 700c |\n| A — Seat tube | 39.0 | 44.0 | 50.0 | 55.0 |\n| B — Seat tube angle | 71.0° | 71.0° | 71.0° | 71.0° |\n| C — Head tube length | 16.0 | 16.0 | 18.0 | 20.0 |\n| D — Head angle | 68.2° | 68.2° | 68.2° | 68.2° |\n| E — Effective top tube | 58.2 | 58.7 | 59.8 | 61.0 |\n| F — Bottom bracket height | 29.4 | 29.4 | 29.4 | 29.4 |\n| G — Bottom bracket drop | 6.5 | 6.5 | 6.5 | 6.5 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 |\n| J — Trail | 9.5 | 9.5 | 9.5 | 9.5 |\n| K — Wheelbase | 112.2 | 112.7 | 114.0 | 115.2 |\n| L — Standover | 43.3 | 43.3 | 43.3 | 43.3 |\n| M — Frame reach | 36.5 | 38.5 | 38.5 | 38.5 |\n| N — Frame stack | 63.0 | 63.0 | 64.9 | 66.8 |\n| Stem length | 8.5 | 9.0 | 9.0 | 11.0 |", + "price": 4499.99, + "tags": [ + "bicycle", + "road bike", + "professional" + ] + }, + { + "name": "FastRider X1", + "shortDescription": "FastRider X1 is a high-performance e-bike designed for riders seeking speed and long-distance capabilities. Equipped with a powerful motor and a high-capacity battery, the FastRider X1 is perfect for daily commuters and e-bike enthusiasts. It boasts a sleek and functional design, making it a great alternative to car transportation. The bike also features a smartphone controller for easy navigation and entertainment options.", + "description": "## Overview\nIt's right for you if...\nYou're looking for an e-bike that offers both speed and endurance. The FastRider X1 comes with a high-performance motor and a long-lasting battery, making it ideal for long-distance rides.\n\nThe tech you get\nThe FastRider X1 features a state-of-the-art motor and a spacious battery, ensuring a fast and efficient ride.\n\nThe final word\nWith the powerful motor and long-range battery, the FastRider X1 allows you to cover more distance at higher speeds.\n\n## Features\nConnect Your Ride with the FastRider App\nDownload the FastRider app and transform your smartphone into an on-board computer. Easily dock and charge your phone with the smartphone controller, and use the thumb pad on your handlebar to make calls, listen to music, get turn-by-turn directions, and more. The app also allows you to connect with fitness and health apps, syncing your routes and ride data.\n\nGoodbye, Car. Hello, Extended Range!\nWith the option to add the Range Boost feature, you can attach a second long-range battery to your FastRider X1, doubling the distance and time between charges. This enhancement allows you to ride longer, commute farther, and take on more adventurous routes.\n\nWhat is the range?\nTo estimate the distance you can travel on a single charge, use our range calculator tool. It automatically fills in the variables for this specific bike model and assumes an average rider, but you can adjust the settings to get the most accurate estimate for your needs.\n\n## Specifications\nFrameset\n- Frame: High-performance hydroformed alloy, Removable Integrated Battery, Range Boost-compatible, internal cable routing, Motor Armour, post-mount disc, 135x5 mm QR\n- Fork: FastRider rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru axle, post mount disc brake\n- Max compatible fork travel: 63mm\n\nWheels\n- Front Hub: FastRider sealed bearing, 32-hole 15mm alloy thru-axle\n- Front Skewer: FastRider Switch thru axle, removable lever\n- Rear Hub: FastRider alloy, sealed bearing, 6-bolt, 135x5mm QR\n- Rear Skewer: 148x5mm bolt-on\n- Rim: FastRider MD35, tubeless compatible, 32-hole, 35mm width, Presta valve\n- Spokes: Size: M, L, XL - 14g stainless steel, black\n- Tire: FastRider E6 Hard-Case Lite, reflective strip, 27.5x2.40''\n- Max tire size: 27.5x2.40\"\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear derailleur: Size: M, L, XL - Shimano Deore M5120, long cage\n- Crank: Size: M - FastRider alloy, 170mm length / Size: L, XL - FastRider alloy, 175mm length\n- Chainring: FastRider 46T narrow/wide alloy, w/alloy guard\n- Cassette: Size: M, L, XL - Shimano Deore M4100, 11-42, 10 speed\n- Chain: Size: M, L, XL - KMC E10 / Size: M, L, XL - KMC X10e\n- Pedal: Size: M, L, XL - FastRider City pedals / Size: M, L, XL - Wellgo C157, boron axle, plastic body / Size: M, L, XL - slip-proof aluminum pedals with reflectors\n- Max chainring size: 1x: 48T\n\nComponents\n- Saddle: FastRider Commuter Comp\n- Seatpost: FastRider Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\n- Handlebar: Size: M - FastRider alloy, 31.8mm, 15mm rise, 600mm width / Size: L, XL - FastRider alloy, 31.8mm, 15mm rise, 660mm width\n- Grips: FastRider Satellite Elite, alloy lock-on\n- Stem: Size: M - FastRider alloy, 31.8mm, Blendr compatible, 7-degree, 70mm length / Size: L - FastRider alloy, 31.8mm, Blendr compatible, 7-degree, 90mm length / Size: XL - FastRider alloy, 31.8mm, Blendr compatible, 7-degree, 100mm length\n- Headset: Size: M, L, XL - FSA IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom / Size: M, L, XL - FSA Integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\n- Brake: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\n- Brake rotor: Shimano RT56, 6-bolt, 180mm\n- Rotor size: Max brake rotor sizes: 180mm front & rear\n\nAccessories\n- Battery: FastRider PowerTube 625Wh\n- Charger: FastRider standard 4A, 100-240V\n- Motor: FastRider Performance Speed, 85 Nm, 28 mph / 45 kph\n- Light: Size: M, L, XL - FastRider taillight, 50 lumens / Size: M, L, XL - FastRider headlight, 500 lumens\n- Kickstand: Size: M, L, XL - Rear mount, alloy / Size: M, L, XL - Adjustable length alloy kickstand\n- Cargo rack: FastRider integrated rear rack, aluminum\n- Fender: FastRider custom aluminum\n\nWeight\n- Weight: M - 25.54 kg / 56.3 lbs\n\nWeight limit\n- This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |\n| N — Frame stack | 62.3 | 65.2 | 68.8 |", + "price": 5499.99, + "tags": [ + "bicycle", + "mountain bike", + "professional" + ] + }, + { + "name": "SonicRide 8S", + "shortDescription": "SonicRide 8S is a high-performance e-bike designed for riders who crave speed and long-distance capabilities. The advanced SonicDrive motor provides powerful assistance up to 28 mph, combined with a durable and long-lasting battery for extended rides. With its sleek design and thoughtful features, the SonicRide 8S is perfect for those who prefer the freedom of riding a bike over driving a car. Plus, it comes equipped with a smartphone controller for easy navigation, music, and more.", + "description": "## Overview\nIt's right for you if...\nYou want a fast and efficient e-bike that can take you long distances. The SonicRide 8S features a hydroformed aluminum frame with a concealed 625Wh battery, a high-powered SonicDrive motor, and a Smartphone Controller. It also includes essential accessories such as lights, fenders, and a rear rack.\n\nThe tech you get\nThe SonicRide 8S is equipped with the fastest SonicDrive motor, ensuring exhilarating rides at high speeds. The long-range battery is perfect for commuters and riders looking to explore new horizons.\n\nThe final word\nWith the SonicDrive motor and long-lasting battery, you can enjoy extended rides at higher speeds.\n\n## Features\n\nConnect Your Ride with SonicRide App\nDownload the SonicRide app and transform your phone into an onboard computer. Simply attach it to the Smartphone Controller for docking and charging. Use the thumb pad on your handlebar to control calls, music, directions, and more. The Bluetooth® wireless technology allows you to connect with fitness and health apps, syncing your routes and ride data.\n\nSay Goodbye to Limited Range with Range Boost!\nExperience the convenience of Range Boost, an additional long-range 500Wh battery that seamlessly attaches to your bike's down tube. This upgrade allows you to double your distance and time between charges, enabling longer commutes and more adventurous rides. Range Boost is compatible with select SonicRide electric bike models.\n\nWhat is the range?\nFor an accurate estimate of how far you can ride on a single charge, use SonicRide's range calculator. We have pre-filled the variables for this specific bike model and the average rider, but you can adjust them to obtain the most accurate estimate.\n\n## Specifications\nFrameset\n- Frame: High-performance hydroformed alloy, Removable Integrated Battery, Range Boost-compatible, internal cable routing, Motor Armour, post-mount disc, 135x5 mm QR\n- Fork: SonicRide rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru axle, post mount disc brake\n- Max compatible fork travel: 63mm\n\nWheels\n- Front Hub: SonicRide sealed bearing, 32-hole 15mm alloy thru-axle\n- Front Skewer: SonicRide Switch thru axle, removable lever\n- Rear Hub: SonicRide alloy, sealed bearing, 6-bolt, 135x5mm QR\n- Rear Skewer: 148x5mm bolt-on\n- Rim: SonicRide MD35, tubeless compatible, 32-hole, 35mm width, Presta valve\n- Spokes: Size: M, L, XL - 14g stainless steel, black\n- Tire: SonicRide E6 Hard-Case Lite, reflective strip, 27.5x2.40''\n- Max tire size: 27.5x2.40\"\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear Derailleur: Size: M, L, XL - Shimano Deore M5120, long cage\n- Crank: Size: M - SonicRide alloy, 170mm length; Size: L, XL - SonicRide alloy, 175mm length\n- Chainring: SonicRide 46T narrow/wide alloy, with alloy guard\n- Cassette: Size: M, L, XL - Shimano Deore M4100, 11-42, 10 speed\n- Chain: Size: M, L, XL - KMC E10; Size: M, L, XL - KMC X10e\n- Pedal: Size: M, L, XL - SonicRide City pedals; Size: M, L, XL - Wellgo C157, boron axle, plastic body; Size: M, L, XL - slip-proof aluminum pedals with reflectors\n- Max chainring size: 1x: 48T\n\nComponents\n- Saddle: SonicRide Commuter Comp\n- Seatpost: SonicRide Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\n- Handlebar: Size: M - SonicRide alloy, 31.8mm, 15mm rise, 600mm width; Size: L, XL - SonicRide alloy, 31.8mm, 15mm rise, 660mm width\n- Grips: SonicRide Satellite Elite, alloy lock-on\n- Stem: Size: M - SonicRide alloy, 31.8mm, Blendr compatible, 7-degree, 70mm length; Size: L - SonicRide alloy, 31.8mm, Blendr compatible, 7-degree, 90mm length; Size: XL - SonicRide alloy, 31.8mm, Blendr compatible, 7-degree, 100mm length\n- Headset: Size: M, L, XL - SonicRide IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom; Size: M, L, XL - SonicRide Integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\n- Brake: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\n- Brake rotor: Shimano RT56, 6-bolt, 180mm\n- Rotor size: Max brake rotor sizes: 180mm front & rear\n\nAccessories\n- Battery: SonicRide PowerTube 625Wh\n- Charger: SonicRide standard 4A, 100-240V\n- Motor: SonicRide Performance Speed, 85 Nm, 28 mph / 45 kph\n- Light: Size: M, L, XL - SonicRide Lync taillight, 50 lumens; Size: M, L, XL - SonicRide Lync headlight, 500 lumens\n- Kickstand: Size: M, L, XL - Rear mount, alloy; Size: M, L, XL - Adjustable length alloy kickstand\n- Cargo rack: SonicRide integrated rear rack, aluminum\n- Fender: SonicRide custom aluminum\n\nWeight\n- Weight: M - 25.54 kg / 56.3 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| M | 165 - 175 cm / 5'5\" - 5'9\" | 77 - 83 cm / 30\" - 33\" |\n| L | 175 - 186 cm / 5'9\" - 6'1\" | 82 - 88 cm / 32\" - 35\" |\n| XL | 186 - 197 cm / 6'1\" - 6'6\" | 87 - 93 cm / 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |", + "price": 5999.99, + "tags": [ + "bicycle", + "road bike", + "professional" + ] + }, + { + "name": "SwiftVolt Pro", + "shortDescription": "SwiftVolt Pro is a high-performance e-bike designed for riders seeking a thrilling and fast riding experience. Equipped with a powerful SwiftDrive motor that provides assistance up to 30 mph and a long-lasting battery, this bike is perfect for long-distance commuting and passionate e-bike enthusiasts. The sleek and innovative design features cater specifically to individuals who prioritize cycling over driving. Additionally, the bike is seamlessly integrated with your smartphone, allowing you to use it for navigation, music, and more.", + "description": "## Overview\nThis bike is ideal for you if:\n- You desire a sleek and modern hydroformed aluminum frame that houses a 700Wh battery.\n- You want to maintain high speeds of up to 30 mph with the assistance of the SwiftDrive motor.\n- You appreciate the convenience of using your smartphone as a controller, which can be docked and charged on the handlebar.\n\n## Features\n\nConnect with SwiftSync App\nBy downloading the SwiftSync app, your smartphone becomes an interactive on-board computer. Attach it to the handlebar-mounted controller for easy access and charging. With the thumb pad, you can make calls, listen to music, receive turn-by-turn directions, and connect with fitness and health apps to track your routes and ride data via Bluetooth® wireless technology.\n\nEnhanced Range with BoostMax\nBoostMax offers the capability to attach a second 700Wh Swift battery to the downtube of your bike, effectively doubling the distance and time between charges. This allows for extended rides, longer commutes, and more significant adventures. BoostMax is compatible with select Swift electric bike models.\n\nRange Estimation\nFor an estimate of how far you can ride on a single charge, consult the Swift range calculator. The variables are automatically populated based on this bike model and the average rider, but you can modify them to obtain the most accurate estimate.\n\n## Specifications\nFrameset\n- Frame: Lightweight hydroformed alloy, Removable Integrated Battery, BoostMax-compatible, internal cable routing, post-mount disc, 135x5 mm QR\n- Fork: SwiftVolt rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru-axle, post-mount disc brake\n- Max compatible fork travel: 63mm\n\nWheels\n- Front Hub: Swift sealed bearing, 32-hole 15mm alloy thru-axle\n- Front Skewer: Swift Switch thru-axle, removable lever\n- Rear Hub: Swift alloy, sealed bearing, 6-bolt, 135x5mm QR\n- Rear Skewer: 148x5mm bolt-on\n- Rim: SwiftRim, tubeless compatible, 32-hole, 35mm width, Presta valve\n- Spokes: 14g stainless steel, black\n- Tire: Swift E6 Hard-Case Lite, reflective strip, 27.5x2.40''\n- Max tire size: 27.5x2.40\"\n\nDrivetrain\n- Shifter: Shimano Deore M4100, 10 speed\n- Rear Derailleur: Shimano Deore M5120, long cage\n- Crank: Swift alloy, 170mm length\n- Chainring: Swift 46T narrow/wide alloy, w/alloy guard\n- Cassette: Shimano Deore M4100, 11-42, 10 speed\n- Chain: KMC E10\n- Pedal: Swift City pedals\n- Max chainring size: 1x: 48T\n\nComponents\n- Saddle: Swift Commuter Comp\n- Seatpost: Swift Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\n- Handlebar: Swift alloy, 31.8mm, 15mm rise, 600mm width (M), 660mm width (L, XL)\n- Grips: Swift Satellite Elite, alloy lock-on\n- Stem: Swift alloy, 31.8mm, Blendr compatible, 7 degree, 70mm length (M), 90mm length (L), 100mm length (XL)\n- Headset: FSA IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\n- Brakes: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\n- Brake Rotor: Shimano RT56, 6-bolt, 180mm\n- Rotor size: Max 180mm front & rear\n\nAccessories\n- Battery: Swift PowerTube 700Wh\n- Charger: Swift standard 4A, 100-240V\n- Motor: SwiftDrive, 90 Nm, 30 mph / 48 kph\n- Light: Swift Lync taillight, 50 lumens (M, L, XL), Swift Lync headlight, 500 lumens (M, L, XL)\n- Kickstand: Rear mount, alloy (M, L, XL), Adjustable length alloy kickstand (M, L, XL)\n- Cargo rack: SwiftVolt integrated rear rack, aluminum\n- Fender: Swift custom aluminum\n\nWeight\n- Weight: M - 25.54 kg / 56.3 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:---------------------:|:-------------:|\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |\n| N — Frame stack | 62.3 | 65.2 | 68.8 |", + "price": 2499.99, + "tags": [ + "bicycle", + "city bike", + "professional" + ] + }, + { + "name": "AgileEon 9X", + "shortDescription": "AgileEon 9X is a high-performance e-bike designed for riders seeking speed and endurance. Equipped with a robust motor and an extended battery life, this bike is perfect for long-distance commuters and avid e-bike enthusiasts. It boasts innovative features tailored for individuals who prioritize cycling over driving. Additionally, the bike integrates seamlessly with your smartphone, allowing you to access navigation, music, and more.", + "description": "## Overview\nIt's right for you if...\nYou crave speed and want to cover long distances efficiently. The AgileEon 9X features a sleek hydroformed aluminum frame that houses a powerful motor, along with a large-capacity battery for extended rides. It comes equipped with a 10-speed drivetrain, front and rear lighting, fenders, and a rear rack.\n\nThe tech you get\nDesigned for those constantly on the move, this bike includes a state-of-the-art motor and a high-capacity battery, making it an excellent choice for lengthy commutes.\n\nThe final word\nWith the AgileEon 9X, you can push your boundaries and explore new horizons thanks to its powerful motor and long-lasting battery.\n\n## Features\n\nConnect Your Ride with RideMate App\nMake use of the RideMate app to transform your smartphone into an onboard computer. Simply attach it to the RideMate controller to dock and charge, then utilize the thumb pad on your handlebar to make calls, listen to music, receive turn-by-turn directions, and more. The bike also supports Bluetooth® wireless technology, enabling seamless connectivity with fitness and health apps for route syncing and ride data.\n\nGoodbye, car. Hello, Extended Range!\nEnhance your riding experience with the Extended Range option, which allows for the attachment of an additional high-capacity 500Wh battery to your bike's downtube. This doubles the distance and time between charges, enabling longer rides, extended commutes, and more significant adventures. The Extended Range feature is compatible with select AgileEon electric bike models.\n\nWhat is the range?\nTo determine how far you can ride on a single charge, you can utilize the range calculator provided by AgileEon. We have pre-filled the variables for this specific model and an average rider, but adjustments can be made for a more accurate estimation.\n\n## Specifications\nFrameset\nFrame: High-performance hydroformed alloy, Removable Integrated Battery, Extended Range-compatible, internal cable routing, Motor Armor, post-mount disc, 135x5 mm QR\nFork: AgileEon rigid alloy fork, 1-1/8'' steel steerer, 100x15mm thru-axle, post-mount disc brake\nMax compatible fork travel: 63mm\n\nWheels\nFront Hub: AgileEon sealed bearing, 32-hole 15mm alloy thru-axle\nFront Skewer: AgileEon Switch thru-axle, removable lever\nRear Hub: AgileEon alloy, sealed bearing, 6-bolt, 135x5mm QR\nRear Skewer: 148x5mm bolt-on\nRim: AgileEon MD35, tubeless compatible, 32-hole, 35mm width, Presta valve\nSpokes:\n- Size: M, L, XL: 14g stainless steel, black\nTire: AgileEon E6 Hard-Case Lite, reflective strip, 27.5x2.40''\nMax tire size: 27.5x2.40\"\n\nDrivetrain\nShifter: Shimano Deore M4100, 10-speed\nRear derailleur:\n- Size: M, L, XL: Shimano Deore M5120, long cage\nCrank:\n- Size: M: AgileEon alloy, 170mm length\n- Size: L, XL: AgileEon alloy, 175mm length\nChainring: AgileEon 46T narrow/wide alloy, with alloy guard\nCassette:\n- Size: M, L, XL: Shimano Deore M4100, 11-42, 10-speed\nChain:\n- Size: M, L, XL: KMC E10\nPedal:\n- Size: M, L, XL: AgileEon City pedals\nMax chainring size: 1x: 48T\n\nComponents\nSaddle: AgileEon Commuter Comp\nSeatpost: AgileEon Comp, 6061 alloy, 31.6mm, 8mm offset, 330mm length\nHandlebar:\n- Size: M: AgileEon alloy, 31.8mm, 15mm rise, 600mm width\n- Size: L, XL: AgileEon alloy, 31.8mm, 15mm rise, 660mm width\nGrips: AgileEon Satellite Elite, alloy lock-on\nStem:\n- Size: M: AgileEon alloy, 31.8mm, Blendr compatible, 7-degree, 70mm length\n- Size: L: AgileEon alloy, 31.8mm, Blendr compatible, 7-degree, 90mm length\n- Size: XL: AgileEon alloy, 31.8mm, Blendr compatible, 7-degree, 100mm length\nHeadset:\n- Size: M, L, XL: AgileEon IS-2 alloy, integrated, sealed cartridge bearing, 1-1/8'' top, 1.5'' bottom\nBrake: Shimano MT520 4-piston hydraulic disc, post-mount, 180mm rotor\nBrake rotor: Shimano RT56, 6-bolt, 180mm\nRotor size: Max brake rotor sizes: 180mm front & rear\n\nAccessories\nBattery: AgileEon PowerTube 625Wh\nCharger: AgileEon standard 4A, 100-240V\nMotor: AgileEon Performance Speed, 85 Nm, 28 mph / 45 kph\nLight:\n- Size: M, L, XL: AgileEon taillight, 50 lumens\n- Size: M, L, XL: AgileEon headlight, 500 lumens\nKickstand:\n- Size: M, L, XL: Rear mount, alloy\n- Size: M, L, XL: Adjustable length alloy kickstand\nCargo rack: AgileEon integrated rear rack, aluminum\nFender: AgileEon custom aluminum\n\nWeight\nWeight: M - 25.54 kg / 56.3 lbs\nWeight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:------------------------:|:--------------------:|\n| M | 165 - 175 cm 5'5\" - 5'9\" | 77 - 83 cm 30\" - 33\" |\n| L | 175 - 186 cm 5'9\" - 6'1\" | 82 - 88 cm 32\" - 35\" |\n| XL | 186 - 197 cm 6'1\" - 6'6\" | 87 - 93 cm 34\" - 37\" |\n\n## Geometry\n| Frame size letter | M | L | XL |\n|---------------------------|-------|-------|-------|\n| Wheel size | 27.5\" | 27.5\" | 27.5\" |\n| A — Seat tube | 44.6 | 49.1 | 53.4 |\n| B — Seat tube angle | 73.0° | 73.0° | 73.0° |\n| C — Head tube length | 16.5 | 19.5 | 23.0 |\n| D — Head angle | 69.5° | 70.0° | 70.5° |\n| E — Effective top tube | 59.5 | 60.7 | 62.2 |\n| F — Bottom bracket height | 29.5 | 29.5 | 29.5 |\n| G — Bottom bracket drop | 6.0 | 6.0 | 6.0 |\n| H — Chainstay length | 48.7 | 48.7 | 48.7 |\n| I — Offset | 4.4 | 4.4 | 4.4 |\n| J — Trail | 8.6 | 8.1 | 7.9 |\n| K — Wheelbase | 114.6 | 115.0 | 116.4 |\n| L — Standover | 79.5 | 83.7 | 87.9 |\n| M — Frame reach | 40.5 | 40.8 | 41.2 |\n| N — Frame stack | 62.3 | 65.2 | 68.8 |", + "price": 3499.99, + "tags": [ + "bicycle", + "road bike", + "professional" + ] + }, + { + "name": "Stealth R1X Pro", + "shortDescription": "Stealth R1X Pro is a high-performance carbon road bike designed for riders who crave speed and exceptional handling. With its aerodynamic tube shaping, disc brakes, and lightweight carbon wheels, the Stealth R1X Pro offers unparalleled performance for competitive road cycling.", + "description": "## Overview\nIt's right for you if...\nYou're a competitive cyclist looking for a road bike that offers superior performance in terms of speed, handling, and aerodynamics. You want a complete package that includes lightweight carbon wheels, without the need for future upgrades.\n\nThe tech you get\nThe Stealth R1X Pro features a lightweight and aerodynamic carbon frame, an advanced carbon fork, high-performance Shimano Ultegra 11-speed drivetrain, and powerful Ultegra disc brakes. The bike also comes equipped with cutting-edge Bontrager Aeolus Elite 35 carbon wheels.\n\nThe final word\nThe Stealth R1X Pro stands out with its combination of a fast and aerodynamic frame, high-end drivetrain, and top-of-the-line carbon wheels. Whether you're racing on local roads, participating in pro stage races, or engaging in hill climbing competitions, this bike is a formidable choice that delivers an exceptional riding experience.\n\n## Features\nSleek and aerodynamic design\nThe Stealth R1X Pro's aero tube shapes maximize speed and performance, making it faster on climbs and flats alike. The bike also features a streamlined Aeolus RSL bar/stem for improved front-end aerodynamics.\n\nDesigned for all riders\nThe Stealth R1X Pro is designed to provide an outstanding fit for riders of all genders, body types, riding styles, and abilities. It comes equipped with size-specific components to ensure a comfortable and efficient riding position for competitive riders.\n\n## Specifications\nFrameset\n- Frame: Ultralight carbon frame constructed with high-performance 500 Series ADV Carbon. It features Ride Tuned performance tube optimization, a tapered head tube, internal routing, DuoTrap S compatibility, flat mount disc brake mounts, and a 142x12mm thru axle.\n- Fork: Full carbon fork (Émonda SL) with a tapered carbon steerer, internal brake routing, flat mount disc brake mounts, and a 12x100mm thru axle.\n- Frame fit: H1.5 Race geometry.\n\nWheels\n- Front wheel: Bontrager Aeolus Elite 35 carbon wheel with a 35mm rim depth, ADV Carbon construction, Tubeless Ready compatibility, and a 100x12mm thru axle.\n- Rear wheel: Bontrager Aeolus Elite 35 carbon wheel with a 35mm rim depth, ADV Carbon construction, Tubeless Ready compatibility, Shimano 11/12-speed freehub, and a 142x12mm thru axle.\n- Front skewer: Bontrager Switch thru axle with a removable lever.\n- Rear skewer: Bontrager Switch thru axle with a removable lever.\n- Tire: Bontrager R2 Hard-Case Lite with an aramid bead, 60 tpi, and a size of 700x25c.\n- Maximum tire size: 28mm.\n\nDrivetrain\n- Shifter:\n - Size 47, 50, 52: Shimano Ultegra R8025 with short-reach levers, 11-speed.\n - Size 54, 56, 58, 60, 62: Shimano Ultegra R8020, 11-speed.\n- Front derailleur: Shimano Ultegra R8000, braze-on.\n- Rear derailleur: Shimano Ultegra R8000, short cage, with a maximum cog size of 30T.\n- Crank:\n - Size 47: Shimano Ultegra R8000 with 52/36 chainrings and a 165mm length.\n - Size 50, 52: Shimano Ultegra R8000 with 52/36 chainrings and a 170mm length.\n - Size 54, 56, 58: Shimano Ultegra R8000 with 52/36 chainrings and a 172.5mm length.\n - Size 60, 62: Shimano Ultegra R8000 with 52/36 chainrings and a 175mm length.\n- Bottom bracket: Praxis T47 threaded bottom bracket with internal bearings.\n- Cassette: Shimano Ultegra R8000, 11-30, 11-speed.\n- Chain: Shimano Ultegra HG701, 11-speed.\n- Maximum chainring size: 1x - 50T, 2x - 53/39.\n\nComponents\n- Saddle: Bontrager Aeolus Comp with steel rails and a width of 145mm.\n- Seatpost:\n - Size 47, 50, 52, 54: Bontrager carbon seatmast cap with a 20mm offset and a short length.\n - Size 56, 58, 60, 62: Bontrager carbon seatmast cap with a 20mm offset and a tall length.\n- Handlebar:\n - Size 47, 50: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 38cm.\n - Size 52: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 40cm.\n - Size 54, 56, 58: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 42cm.\n - Size 60, 62: Bontrager Elite VR-C alloy handlebar with a 31.8mm clamp, 100mm reach, 124mm drop, and a width of 44cm.\n- Handlebar tape: Bontrager Supertack Perf tape.\n- Stem:\n - Size 47: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 70mm.\n - Size 50: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 80mm.\n - Size 52, 54: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 90mm.\n - Size 56: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 100mm.\n - Size 58, 60, 62: Bontrager Pro alloy stem with a 31.8mm clamp, Blendr compatibility, 7-degree rise, and a length of 110mm.\n- Brake: Shimano Ultegra hydraulic disc brakes with flat mount calipers.\n- Brake rotor: Shimano RT800 with centerlock mounting, 160mm diameter.\n\nWeight\n- Weight: 8.03 kg (17.71 lbs) for the 56cm frame.\n- Weight limit: The bike has a maximum total weight limit (combined weight of the bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\nPlease refer to the table below for the corresponding Stealth R1X Pro frame sizes, recommended rider height range, and inseam measurements:\n\n| Size | Rider Height | Inseam |\n|:----:|:---------------------:|:--------------:|\n| 47 | 152 - 158 cm (5'0\") | 71 - 75 cm |\n| 50 | 158 - 163 cm (5'2\") | 74 - 77 cm |\n| 52 | 163 - 168 cm (5'4\") | 76 - 79 cm |\n| 54 | 168 - 174 cm (5'6\") | 78 - 82 cm |\n| 56 | 174 - 180 cm (5'9\") | 81 - 85 cm |\n| 58 | 180 - 185 cm (5'11\") | 84 - 87 cm |\n| 60 | 185 - 190 cm (6'1\") | 86 - 90 cm |\n| 62 | 190 - 195 cm (6'3\") | 89 - 92 cm |\n\n## Geometry\nThe table below provides the geometry measurements for each frame size of the Stealth R1X Pro:\n\n| Frame size number | 47 cm | 50 cm | 52 cm | 54 cm | 56 cm | 58 cm | 60 cm | 62 cm |\n|-------------------------------|-------|-------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 42.4 | 45.3 | 48.3 | 49.6 | 52.5 | 55.3 | 57.3 | 59.3 |\n| B — Seat tube angle | 74.6° | 74.6° | 74.2° | 73.7° | 73.3° | 73.0° | 72.8° | 72.5° |\n| C — Head tube length | 10.0 | 11.1 | 12.1 | 13.1 | 15.1 | 17.1 | 19.1 | 21.1 |\n| D — Head angle | 72.1° | 72.1° | 72.8° | 73.0° | 73.5° | 73.8° | 73.9° | 73.9° |\n| E — Effective top tube | 51.2 | 52.1 | 53.4 | 54.3 | 55.9 | 57.4 | 58.6 | 59.8 |\n| G — Bottom bracket drop | 7.2 | 7.2 | 7.2 | 7.0 | 7.0 | 6.8 | 6.8 | 6.8 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 | 41.0 | 41.1 | 41.1 | 41.2 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 | 4.0 | 4.0 | 4.0 | 4.0 |\n| J — Trail | 6.8 | 6.2 | 5.8 | 5.6 | 5.8 | 5.7 | 5.6 | 5.6 |\n| K — Wheelbase | 97.2 | 97.4 | 97.7 | 98.1 | 98.3 | 99.2 | 100.1 | 101.0 |\n| L — Standover | 69.2 | 71.1 | 73.2 | 74.4 | 76.8 | 79.3 | 81.1 | 82.9 |\n| M — Frame reach | 37.3 | 37.8 | 38.3 | 38.6 | 39.1 | 39.6 | 39.9 | 40.3 |\n| N — Frame stack | 50.7 | 52.1 | 53.3 | 54.1 | 56.3 | 58.1 | 60.1 | 62.0 |\n| Saddle rail height min (short mast) | 55.5 | 58.5 | 61.5 | 64.0 | 67.0 | 69.0 | 71.0 | 73.0 |\n| Saddle rail height max (short mast) | 61.5 | 64.5 | 67.5 | 70.0 | 73.0 | 75.0 | 77.0 | 79.0 |\n| Saddle rail height min (tall mast) | 59.0 | 62.0 | 65.0 | 67.5 | 70.5 | 72.5 | 74.5 | 76.5 |\n| Saddle rail height max (tall mast) | 65.0 | 68.0 | 71.0 | 73.5 | 76.5 | 78.5 | 80.5 | 82.5 |", + "price": 2999.99, + "tags": [ + "bicycle", + "mountain bike", + "professional" + ] + }, + { + "name": "Avant SLR 6 Disc Pro", + "shortDescription": "Avant SLR 6 Disc Pro is a high-performance carbon road bike designed for riders who prioritize speed and handling. With its aero tube shaping, disc brakes, and lightweight carbon wheels, it offers the perfect balance of speed and control.", + "description": "## Overview\nIt's right for you if...\nYou're a rider who values exceptional performance on fast group rides and races, and you want a complete package that includes lightweight carbon wheels. The Avant SLR 6 Disc Pro is designed to provide the speed and aerodynamics you need to excel on any road.\n\nThe tech you get\nThe Avant SLR 6 Disc Pro features a lightweight 500 Series ADV Carbon frame and fork, Bontrager Aeolus Elite 35 carbon wheels, a full Shimano Ultegra 11-speed drivetrain, and powerful Ultegra disc brakes.\n\nThe final word\nThe standout feature of this bike is the combination of its aero frame, high-performance drivetrain, and top-quality carbon wheels. Whether you're racing, tackling challenging climbs, or participating in professional stage races, the Avant SLR 6 Disc Pro is a worthy choice that will enhance your performance.\n\n## Features\nAll-new aero design\nThe Avant SLR 6 Disc Pro features innovative aero tube shapes that provide an advantage in all riding conditions, whether it's climbing or riding on flat roads. Additionally, it is equipped with a sleek new Aeolus RSL bar/stem that enhances front-end aero performance.\n\nAwesome bikes for everyone\nThe Avant SLR 6 Disc Pro is designed with the belief that every rider, regardless of gender, body type, riding style, or ability, deserves a great bike. It is equipped with size-specific components that ensure a perfect fit for competitive riders of all genders.\n\n## Specifications\nFrameset\n- Frame: Ultralight 500 Series ADV Carbon, Ride Tuned performance tube optimization, tapered head tube, internal routing, DuoTrap S compatible, flat mount disc, 142x12mm thru axle\n- Fork: Avant SL full carbon, tapered carbon steerer, internal brake routing, flat mount disc, 12x100mm thru axle\n- Frame fit: H1.5 Race\n\nWheels\n- Front wheel: Bontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, 100x12mm thru axle\n- Rear wheel: Bontrager Aeolus Elite 35, ADV Carbon, Tubeless Ready, 35mm rim depth, Shimano 11/12-speed freehub, 142x12mm thru axle\n- Front skewer: Bontrager Switch thru axle, removable lever\n- Rear skewer: Bontrager Switch thru axle, removable lever\n- Tire: Bontrager R2 Hard-Case Lite, aramid bead, 60 tpi, 700x25c\n- Max tire size: 28mm\n\nDrivetrain\n- Shifter: \n - Size 47, 50, 52: Shimano Ultegra R8025, short-reach lever, 11-speed\n - Size 54, 56, 58, 60, 62: Shimano Ultegra R8020, 11-speed\n- Front derailleur: Shimano Ultegra R8000, braze-on\n- Rear derailleur: Shimano Ultegra R8000, short cage, 30T max cog\n- Crank: \n - Size 47: Shimano Ultegra R8000, 52/36, 165mm length\n - Size 50, 52: Shimano Ultegra R8000, 52/36, 170mm length\n - Size 54, 56, 58: Shimano Ultegra R8000, 52/36, 172.5mm length\n - Size 60, 62: Shimano Ultegra R8000, 52/36, 175mm length\n- Bottom bracket: Praxis, T47 threaded, internal bearing\n- Cassette: Shimano Ultegra R8000, 11-30, 11-speed\n- Chain: Shimano Ultegra HG701, 11-speed\n- Max chainring size: 1x: 50T, 2x: 53/39\n\nComponents\n- Saddle: Bontrager Aeolus Comp, steel rails, 145mm width\n- Seatpost: \n - Size 47, 50, 52, 54: Bontrager carbon seatmast cap, 20mm offset, short length\n - Size 56, 58, 60, 62: Bontrager carbon seatmast cap, 20mm offset, tall length\n- Handlebar: \n - Size 47, 50: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 38cm width\n - Size 52: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 40cm width\n - Size 54, 56, 58: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 42cm width\n - Size 60, 62: Bontrager Elite VR-C, alloy, 31.8mm, 100mm reach, 124mm drop, 44cm width\n- Handlebar tape: Bontrager Supertack Perf tape\n- Stem: \n - Size 47: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 70mm length\n - Size 50: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 80mm length\n - Size 52, 54: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 90mm length\n - Size 56: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 100mm length\n - Size 58, 60, 62: Bontrager Pro, 31.8mm, Blendr compatible, 7-degree, 110mm length\n- Brake: Shimano Ultegra hydraulic disc, flat mount\n- Brake rotor: Shimano RT800, centerlock, 160mm\n\nWeight\n- Weight: 56 - 8.03 kg / 17.71 lbs\n- Weight limit: This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 275 pounds (125 kg).\n\n## Sizing\n| Size | Rider Height | Inseam |\n|:----:|:-------------------------:|:--------------------:|\n| 47 | 152 - 158 cm 5'0\" - 5'2\" | 71 - 75 cm 28\" - 30\" |\n| 50 | 158 - 163 cm 5'2\" - 5'4\" | 74 - 77 cm 29\" - 30\" |\n| 52 | 163 - 168 cm 5'4\" - 5'6\" | 76 - 79 cm 30\" - 31\" |\n| 54 | 168 - 174 cm 5'6\" - 5'9\" | 78 - 82 cm 31\" - 32\" |\n| 56 | 174 - 180 cm 5'9\" - 5'11\" | 81 - 85 cm 32\" - 33\" |\n| 58 | 180 - 185 cm 5'11\" - 6'1\" | 84 - 87 cm 33\" - 34\" |\n| 60 | 185 - 190 cm 6'1\" - 6'3\" | 86 - 90 cm 34\" - 35\" |\n| 62 | 190 - 195 cm 6'3\" - 6'5\" | 89 - 92 cm 35\" - 36\" |\n\n## Geometry\n| Frame size number | 47 cm | 50 cm | 52 cm | 54 cm | 56 cm | 58 cm | 60 cm | 62 cm |\n|---------------------------------------|-------|-------|-------|-------|-------|-------|-------|-------|\n| Wheel size | 700c | 700c | 700c | 700c | 700c | 700c | 700c | 700c |\n| A — Seat tube | 42.4 | 45.3 | 48.3 | 49.6 | 52.5 | 55.3 | 57.3 | 59.3 |\n| B — Seat tube angle | 74.6° | 74.6° | 74.2° | 73.7° | 73.3° | 73.0° | 72.8° | 72.5° |\n| C — Head tube length | 10.0 | 11.1 | 12.1 | 13.1 | 15.1 | 17.1 | 19.1 | 21.1 |\n| D — Head angle | 72.1° | 72.1° | 72.8° | 73.0° | 73.5° | 73.8° | 73.9° | 73.9° |\n| E — Effective top tube | 51.2 | 52.1 | 53.4 | 54.3 | 55.9 | 57.4 | 58.6 | 59.8 |\n| G — Bottom bracket drop | 7.2 | 7.2 | 7.2 | 7.0 | 7.0 | 6.8 | 6.8 | 6.8 |\n| H — Chainstay length | 41.0 | 41.0 | 41.0 | 41.0 | 41.0 | 41.1 | 41.1 | 41.2 |\n| I — Offset | 4.5 | 4.5 | 4.5 | 4.5 | 4.0 | 4.0 | 4.0 | 4.0 |\n| J — Trail | 6.8 | 6.2 | 5.8 | 5.6 | 5.8 | 5.7 | 5.6 | 5.6 |\n| K — Wheelbase | 97.2 | 97.4 | 97.7 | 98.1 | 98.3 | 99.2 | 100.1 | 101.0 |\n| L — Standover | 69.2 | 71.1 | 73.2 | 74.4 | 76.8 | 79.3 | 81.1 | 82.9 |\n| M — Frame reach | 37.3 | 37.8 | 38.3 | 38.6 | 39.1 | 39.6 | 39.9 | 40.3 |\n| N — Frame stack | 50.7 | 52.1 | 53.3 | 54.1 | 56.3 | 58.1 | 60.1 | 62.0 |\n| Saddle rail height min (w/short mast) | 55.5 | 58.5 | 61.5 | 64.0 | 67.0 | 69.0 | 71.0 | 73.0 |\n| Saddle rail height max (w/short mast) | 61.5 | 64.5 | 67.5 | 70.0 | 73.0 | 75.0 | 77.0 | 79.0 |\n| Saddle rail height min (w/tall mast) | 59.0 | 62.0 | 65.0 | 67.5 | 70.5 | 72.5 | 74.5 | 76.5 |\n| Saddle rail height max (w/tall mast) | 65.0 | 68.0 | 71.0 | 73.5 | 76.5 | 78.5 | 80.5 | 82.5 |", + "price": 999.99, + "tags": [ + "bicycle", + "city bike", + "professional" + ] + } +] diff --git a/spring-ai-core/src/test/resources/logback.xml b/spring-ai-core/src/test/resources/logback.xml index 7030fdba8..d8a39b52c 100644 --- a/spring-ai-core/src/test/resources/logback.xml +++ b/spring-ai-core/src/test/resources/logback.xml @@ -1,16 +1,32 @@ + + - - - %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} -%kvp- %msg%n - - + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} -%kvp- %msg%n + + - - - - - - + + + + + + - \ No newline at end of file + diff --git a/spring-ai-docs/pom.xml b/spring-ai-docs/pom.xml index c82441a43..69f8952b3 100644 --- a/spring-ai-docs/pom.xml +++ b/spring-ai-docs/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-docs/src/assembly/javadocs.xml b/spring-ai-docs/src/assembly/javadocs.xml index bde989c22..709caaddc 100644 --- a/spring-ai-docs/src/assembly/javadocs.xml +++ b/spring-ai-docs/src/assembly/javadocs.xml @@ -1,3 +1,19 @@ + + diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/no.svg b/spring-ai-docs/src/main/antora/modules/ROOT/images/no.svg index 36f90f818..256b5924f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/images/no.svg +++ b/spring-ai-docs/src/main/antora/modules/ROOT/images/no.svg @@ -1,22 +1,41 @@ - - + + + + cancel Created with Sketch. - - - - + + + + + + - - + + + diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-integration-diagram-3.svg b/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-integration-diagram-3.svg index 98fb78a05..ab80895aa 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-integration-diagram-3.svg +++ b/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-integration-diagram-3.svg @@ -1,21 +1,37 @@ + + + + + version="1.1" + id="svg1" + width="1022" + height="239.33333" + viewBox="0 0 1022 239.33333" + sodipodi:docname="spring_ai_logo copy.svg" + inkscape:version="1.3.2 (091e20e, 2023-11-25)" + xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" + xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd" + xmlns="http://www.w3.org/2000/svg" +> + + + \ No newline at end of file diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/advisors.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/advisors.adoc index 7069b6ced..6e686b216 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/advisors.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/advisors.adoc @@ -18,7 +18,7 @@ var chatClient = ChatClient.builder(chatModel) ) .build(); -String response = chatClient.prompt() +String response = this.chatClient.prompt() // Set advisor parameters at runtime .advisors(advisor -> advisor.param("chat_memory_conversation_id", "678") .param("chat_memory_response_size", 100)) diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/aimetadata.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/aimetadata.adoc index 6efd46764..e41e8eb53 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/aimetadata.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/aimetadata.adoc @@ -97,7 +97,7 @@ The `RateLimit` instance can be acquired from the `GenerationMetadata`, like so: ---- RateLimit rateLimit = generationMetadata.getRateLimit(); -Long tokensRemaining = rateLimit.getTokensRemaining(); +Long tokensRemaining = this.rateLimit.getTokensRemaining(); // do something interesting with the RateLimit metadata ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/speech/openai-speech.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/speech/openai-speech.adoc index 2ed3840e8..9022ce19d 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/speech/openai-speech.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/speech/openai-speech.adoc @@ -95,8 +95,8 @@ OpenAiAudioSpeechOptions speechOptions = OpenAiAudioSpeechOptions.builder() .withSpeed(1.0f) .build(); -SpeechPrompt speechPrompt = new SpeechPrompt("Hello, this is a text-to-speech example.", speechOptions); -SpeechResponse response = openAiAudioSpeechModel.call(speechPrompt); +SpeechPrompt speechPrompt = new SpeechPrompt("Hello, this is a text-to-speech example.", this.speechOptions); +SpeechResponse response = openAiAudioSpeechModel.call(this.speechPrompt); ---- == Manual Configuration @@ -128,7 +128,7 @@ Next, create an `OpenAiAudioSpeechModel`: ---- var openAiAudioApi = new OpenAiAudioApi(System.getenv("OPENAI_API_KEY")); -var openAiAudioSpeechModel = new OpenAiAudioSpeechModel(openAiAudioApi); +var openAiAudioSpeechModel = new OpenAiAudioSpeechModel(this.openAiAudioApi); var speechOptions = OpenAiAudioSpeechOptions.builder() .withResponseFormat(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3) @@ -136,13 +136,13 @@ var speechOptions = OpenAiAudioSpeechOptions.builder() .withModel(OpenAiAudioApi.TtsModel.TTS_1.value) .build(); -var speechPrompt = new SpeechPrompt("Hello, this is a text-to-speech example.", speechOptions); -SpeechResponse response = openAiAudioSpeechModel.call(speechPrompt); +var speechPrompt = new SpeechPrompt("Hello, this is a text-to-speech example.", this.speechOptions); +SpeechResponse response = this.openAiAudioSpeechModel.call(this.speechPrompt); // Accessing metadata (rate limit info) -OpenAiAudioSpeechResponseMetadata metadata = response.getMetadata(); +OpenAiAudioSpeechResponseMetadata metadata = this.response.getMetadata(); -byte[] responseAsBytes = response.getResult().getOutput(); +byte[] responseAsBytes = this.response.getResult().getOutput(); ---- == Streaming Real-time Audio @@ -153,7 +153,7 @@ The Speech API provides support for real-time audio streaming using chunk transf ---- var openAiAudioApi = new OpenAiAudioApi(System.getenv("OPENAI_API_KEY")); -var openAiAudioSpeechModel = new OpenAiAudioSpeechModel(openAiAudioApi); +var openAiAudioSpeechModel = new OpenAiAudioSpeechModel(this.openAiAudioApi); OpenAiAudioSpeechOptions speechOptions = OpenAiAudioSpeechOptions.builder() .withVoice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY) @@ -162,9 +162,9 @@ OpenAiAudioSpeechOptions speechOptions = OpenAiAudioSpeechOptions.builder() .withModel(OpenAiAudioApi.TtsModel.TTS_1.value) .build(); -SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", speechOptions); +SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", this.speechOptions); -Flux responseStream = openAiAudioSpeechModel.stream(speechPrompt); +Flux responseStream = this.openAiAudioSpeechModel.stream(this.speechPrompt); ---- == Example Code diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/azure-openai-transcriptions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/azure-openai-transcriptions.adoc index a0cc7d4d1..40b283e85 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/azure-openai-transcriptions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/azure-openai-transcriptions.adoc @@ -66,10 +66,10 @@ AzureOpenAiAudioTranscriptionOptions transcriptionOptions = AzureOpenAiAudioTran .withLanguage("en") .withPrompt("Ask not this, but ask that") .withTemperature(0f) - .withResponseFormat(responseFormat) + .withResponseFormat(this.responseFormat) .build(); -AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); -AudioTranscriptionResponse response = azureOpenAiTranscriptionModel.call(transcriptionRequest); +AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, this.transcriptionOptions); +AudioTranscriptionResponse response = azureOpenAiTranscriptionModel.call(this.transcriptionRequest); ---- == Manual Configuration @@ -104,7 +104,7 @@ var openAIClient = new OpenAIClientBuilder() .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .buildClient(); -var azureOpenAiAudioTranscriptionModel = new AzureOpenAiAudioTranscriptionModel(openAIClient, null); +var azureOpenAiAudioTranscriptionModel = new AzureOpenAiAudioTranscriptionModel(this.openAIClient, null); var transcriptionOptions = AzureOpenAiAudioTranscriptionOptions.builder() .withResponseFormat(TranscriptResponseFormat.TEXT) @@ -113,6 +113,6 @@ var transcriptionOptions = AzureOpenAiAudioTranscriptionOptions.builder() var audioFile = new FileSystemResource("/path/to/your/resource/speech/jfk.flac"); -AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); -AudioTranscriptionResponse response = azureOpenAiAudioTranscriptionModel.call(transcriptionRequest); +AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(this.audioFile, this.transcriptionOptions); +AudioTranscriptionResponse response = this.azureOpenAiAudioTranscriptionModel.call(this.transcriptionRequest); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/openai-transcriptions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/openai-transcriptions.adoc index 5bfde4f53..da3119376 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/openai-transcriptions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/audio/transcriptions/openai-transcriptions.adoc @@ -94,10 +94,10 @@ OpenAiAudioTranscriptionOptions transcriptionOptions = OpenAiAudioTranscriptionO .withLanguage("en") .withPrompt("Ask not this, but ask that") .withTemperature(0f) - .withResponseFormat(responseFormat) + .withResponseFormat(this.responseFormat) .build(); -AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); -AudioTranscriptionResponse response = openAiTranscriptionModel.call(transcriptionRequest); +AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, this.transcriptionOptions); +AudioTranscriptionResponse response = openAiTranscriptionModel.call(this.transcriptionRequest); ---- == Manual Configuration @@ -129,7 +129,7 @@ Next, create a `OpenAiAudioTranscriptionModel` ---- var openAiAudioApi = new OpenAiAudioApi(System.getenv("OPENAI_API_KEY")); -var openAiAudioTranscriptionModel = new OpenAiAudioTranscriptionModel(openAiAudioApi); +var openAiAudioTranscriptionModel = new OpenAiAudioTranscriptionModel(this.openAiAudioApi); var transcriptionOptions = OpenAiAudioTranscriptionOptions.builder() .withResponseFormat(TranscriptResponseFormat.TEXT) @@ -138,8 +138,8 @@ var transcriptionOptions = OpenAiAudioTranscriptionOptions.builder() var audioFile = new FileSystemResource("/path/to/your/resource/speech/jfk.flac"); -AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions); -AudioTranscriptionResponse response = openAiTranscriptionModel.call(transcriptionRequest); +AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(this.audioFile, this.transcriptionOptions); +AudioTranscriptionResponse response = openAiTranscriptionModel.call(this.transcriptionRequest); ---- == Example Code diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc index 7e8897a6e..09b24e876 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc @@ -161,9 +161,9 @@ Below is a simple code example extracted from https://github.com/spring-projects byte[] imageData = new ClassPathResource("/multimodal.test.png").getContentAsByteArray(); var userMessage = new UserMessage("Explain what do you see on this picture?", - List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); + List.of(new Media(MimeTypeUtils.IMAGE_PNG, this.imageData))); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage))); +ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage))); logger.info(response.getResult().getOutput().getContent()); ---- @@ -219,13 +219,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -261,18 +261,18 @@ Next, create a `AnthropicChatModel` and use it for text generations: ---- var anthropicApi = new AnthropicApi(System.getenv("ANTHROPIC_API_KEY")); -var chatModel = new AnthropicChatModel(anthropicApi, +var chatModel = new AnthropicChatModel(this.anthropicApi, AnthropicChatOptions.builder() .withModel("claude-3-opus-20240229") .withTemperature(0.4) .withMaxTokens(200) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -300,14 +300,14 @@ AnthropicMessage chatCompletionMessage = new AnthropicMessage( List.of(new ContentBlock("Tell me a Joke?")), Role.USER); // Sync request -ResponseEntity response = anthropicApi +ResponseEntity response = this.anthropicApi .chatCompletionEntity(new ChatCompletionRequest(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(), - List.of(chatCompletionMessage), null, 100, 0.8, false)); + List.of(this.chatCompletionMessage), null, 100, 0.8, false)); // Streaming request -Flux response = anthropicApi +Flux response = this.anthropicApi .chatCompletionStream(new ChatCompletionRequest(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(), - List.of(chatCompletionMessage), null, 100, 0.8, true)); + List.of(this.chatCompletionMessage), null, 100, 0.8, true)); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java[AnthropicApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc index af9dda949..fa3ffe5e6 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc @@ -199,7 +199,7 @@ Below is a code example excerpted from link:https://github.com/spring-projects/s URL url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); String response = ChatClient.create(chatModel).prompt() .options(AzureOpenAiChatOptions.builder().withDeploymentName("gpt-4o").build()) - .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url)) + .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, this.url)) .call() .content(); ---- @@ -231,7 +231,7 @@ String response = ChatClient.create(chatModel).prompt() .options(AzureOpenAiChatOptions.builder() .withDeploymentName("gpt-4o").build()) .user(u -> u.text("Explain what do you see on this picture?") - .media(MimeTypeUtils.IMAGE_PNG, resource)) + .media(MimeTypeUtils.IMAGE_PNG, this.resource)) .call() .content(); ---- @@ -270,13 +270,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -322,13 +322,13 @@ var openAIChatOptions = AzureOpenAiChatOptions.builder() .withMaxTokens(200) .build(); -var chatModel = new AzureOpenAiChatModel(openAIClient, openAIChatOptions); +var chatModel = new AzureOpenAiChatModel(this.openAIClient, this.openAIChatOptions); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic.adoc index c6bd2bce3..5e88b0a24 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic.adoc @@ -155,13 +155,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -202,7 +202,7 @@ AnthropicChatBedrockApi anthropicApi = new AnthropicChatBedrockApi( new ObjectMapper(), Duration.ofMillis(1000L)); -BedrockAnthropicChatModel chatModel = new BedrockAnthropicChatModel(anthropicApi, +BedrockAnthropicChatModel chatModel = new BedrockAnthropicChatModel(this.anthropicApi, AnthropicChatOptions.builder() .withTemperature(0.6) .withTopK(10) @@ -211,11 +211,11 @@ BedrockAnthropicChatModel chatModel = new BedrockAnthropicChatModel(anthropicApi .withAnthropicVersion(AnthropicChatBedrockApi.DEFAULT_ANTHROPIC_VERSION) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -244,11 +244,11 @@ AnthropicChatRequest request = AnthropicChatRequest .build(); // Sync request -AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); +AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(this.request); // Streaming request -Flux responseStream = anthropicChatApi.chatCompletionStream(request); -List responses = responseStream.collectList().block(); +Flux responseStream = this.anthropicChatApi.chatCompletionStream(this.request); +List responses = this.responseStream.collectList().block(); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java[AnthropicChatBedrockApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc index b62a580cf..848a532b4 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc @@ -133,9 +133,9 @@ Below is a simple code example extracted from https://github.com/spring-projects byte[] imageData = new ClassPathResource("/test.png").getContentAsByteArray(); var userMessage = new UserMessage("Explain what do you see o this picture?", - List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); + List.of(new Media(MimeTypeUtils.IMAGE_PNG, this.imageData))); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage))); + ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage))); assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "basket"); ---- @@ -196,13 +196,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -243,7 +243,7 @@ Anthropic3ChatBedrockApi anthropicApi = new Anthropic3ChatBedrockApi( new ObjectMapper(), Duration.ofMillis(1000L)); -BedrockAnthropic3ChatModel chatModel = new BedrockAnthropic3ChatModel(anthropicApi, +BedrockAnthropic3ChatModel chatModel = new BedrockAnthropic3ChatModel(this.anthropicApi, AnthropicChatOptions.builder() .withTemperature(0.6) .withTopK(10) @@ -252,11 +252,11 @@ BedrockAnthropic3ChatModel chatModel = new BedrockAnthropic3ChatModel(anthropicA .withAnthropicVersion(AnthropicChatBedrockApi.DEFAULT_ANTHROPIC_VERSION) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -281,11 +281,11 @@ AnthropicChatRequest request = AnthropicChatRequest .build(); // Sync request -AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); +AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(this.request); // Streaming request -Flux responseStream = anthropicChatApi.chatCompletionStream(request); -List responses = responseStream.collectList().block(); +Flux responseStream = this.anthropicChatApi.chatCompletionStream(this.request); +List responses = this.responseStream.collectList().block(); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java[Anthropic3ChatBedrockApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-cohere.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-cohere.adoc index b7eac40d5..b3df6b5d7 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-cohere.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-cohere.adoc @@ -147,13 +147,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -193,7 +193,7 @@ CohereChatBedrockApi api = new CohereChatBedrockApi(CohereChatModel.COHERE_COMMA new ObjectMapper(), Duration.ofMillis(1000L)); -BedrockCohereChatModel chatModel = new BedrockCohereChatModel(api, +BedrockCohereChatModel chatModel = new BedrockCohereChatModel(this.api, BedrockCohereChatOptions.builder() .withTemperature(0.6) .withTopK(10) @@ -201,11 +201,11 @@ BedrockCohereChatModel chatModel = new BedrockCohereChatModel(api, .withMaxTokens(678) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -242,7 +242,7 @@ var request = CohereChatRequest .withTruncate(Truncate.NONE) .build(); -CohereChatResponse response = cohereChatApi.chatCompletion(request); +CohereChatResponse response = this.cohereChatApi.chatCompletion(this.request); var request = CohereChatRequest .builder("What is the capital of Bulgaria and what is the size? What it the national anthem?") @@ -258,8 +258,8 @@ var request = CohereChatRequest .withTruncate(Truncate.NONE) .build(); -Flux responseStream = cohereChatApi.chatCompletionStream(request); -List responses = responseStream.collectList().block(); +Flux responseStream = this.cohereChatApi.chatCompletionStream(this.request); +List responses = this.responseStream.collectList().block(); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-jurassic2.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-jurassic2.adoc index ca29f165e..5ddad5775 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-jurassic2.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-jurassic2.adoc @@ -140,7 +140,7 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } } @@ -181,13 +181,13 @@ Ai21Jurassic2ChatBedrockApi api = new Ai21Jurassic2ChatBedrockApi(Ai21Jurassic2C new ObjectMapper(), Duration.ofMillis(1000L)); -BedrockAi21Jurassic2ChatModel chatModel = new BedrockAi21Jurassic2ChatModel(api, +BedrockAi21Jurassic2ChatModel chatModel = new BedrockAi21Jurassic2ChatModel(this.api, BedrockAi21Jurassic2ChatOptions.builder() .withTemperature(0.5) .withMaxTokens(100) .withTopP(0.9).build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -214,7 +214,7 @@ Ai21Jurassic2ChatRequest request = Ai21Jurassic2ChatRequest.builder("Hello, my n .withMaxTokens(20) .build(); -Ai21Jurassic2ChatResponse response = jurassic2ChatApi.chatCompletion(request); +Ai21Jurassic2ChatResponse response = this.jurassic2ChatApi.chatCompletion(this.request); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc index a51ca3408..445e4423a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc @@ -145,13 +145,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -191,17 +191,17 @@ LlamaChatBedrockApi api = new LlamaChatBedrockApi(LlamaChatModel.LLAMA2_70B_CHAT new ObjectMapper(), Duration.ofMillis(1000L)); -BedrockLlamaChatModel chatModel = new BedrockLlamaChatModel(api, +BedrockLlamaChatModel chatModel = new BedrockLlamaChatModel(this.api, BedrockLlamaChatOptions.builder() .withTemperature(0.5) .withMaxGenLen(100) .withTopP(0.9).build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -230,11 +230,11 @@ LlamaChatRequest request = LlamaChatRequest.builder("Hello, my name is") .withMaxGenLen(20) .build(); -LlamaChatResponse response = llamaChatApi.chatCompletion(request); +LlamaChatResponse response = this.llamaChatApi.chatCompletion(this.request); // Streaming response -Flux responseStream = llamaChatApi.chatCompletionStream(request); -List responses = responseStream.collectList().block(); +Flux responseStream = this.llamaChatApi.chatCompletionStream(this.request); +List responses = this.responseStream.collectList().block(); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java[LlamaChatBedrockApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-titan.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-titan.adoc index 9f34fe50b..970963d45 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-titan.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-titan.adoc @@ -143,13 +143,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -190,18 +190,18 @@ TitanChatBedrockApi titanApi = new TitanChatBedrockApi( new ObjectMapper(), Duration.ofMillis(1000L)); -BedrockTitanChatModel chatModel = new BedrockTitanChatModel(titanApi, +BedrockTitanChatModel chatModel = new BedrockTitanChatModel(this.titanApi, BedrockTitanChatOptions.builder() .withTemperature(0.6) .withTopP(0.8) .withMaxTokenCount(100) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -229,11 +229,11 @@ TitanChatRequest titanChatRequest = TitanChatRequest.builder("Give me the names .withStopSequences(List.of("|")) .build(); -TitanChatResponse response = titanBedrockApi.chatCompletion(titanChatRequest); +TitanChatResponse response = this.titanBedrockApi.chatCompletion(this.titanChatRequest); -Flux response = titanBedrockApi.chatCompletionStream(titanChatRequest); +Flux response = this.titanBedrockApi.chatCompletionStream(this.titanChatRequest); -List results = response.collectList().block(); +List results = this.response.collectList().block(); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java[TitanChatBedrockApi]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc index eecb5cb53..793e8940a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc @@ -153,7 +153,7 @@ AnthropicChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), AnthropicChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function logger.info("Response: {}", response); @@ -180,7 +180,7 @@ var promptOptions = AnthropicChatOptions.builder() new MockWeatherService()))) // function code .build(); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc index a078e9f9c..93810865c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc @@ -148,7 +148,7 @@ AzureOpenAiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), AzureOpenAiChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function logger.info("Response: {}", response); @@ -185,7 +185,7 @@ var promptOptions = AzureOpenAiChatOptions.builder() .build())) .build(); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/minimax-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/minimax-chat-functions.adoc index 039c3f8d9..3fbeef622 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/minimax-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/minimax-chat-functions.adoc @@ -153,7 +153,7 @@ MiniMaxChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), MiniMaxChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function logger.info("Response: {}", response); @@ -190,7 +190,7 @@ var promptOptions = MiniMaxChatOptions.builder() new MockWeatherService()))) // function code .build(); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc index a0bf2ed38..0cf4e0589 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc @@ -151,7 +151,7 @@ MistralAiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); -ChatResponse response = chatModel.call(new Prompt(userMessage, +ChatResponse response = this.chatModel.call(new Prompt(this.userMessage, MistralAiChatOptions.builder().withFunction("CurrentWeather").build())); // Enable the function logger.info("Response: {}", response); @@ -178,7 +178,7 @@ var promptOptions = MistralAiChatOptions.builder() new MockWeatherService()))) // function code .build(); -ChatResponse response = chatModel.call(new Prompt(userMessage, promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(this.userMessage, this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/moonshot-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/moonshot-chat-functions.adoc index 23a2f9017..fe04e62ea 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/moonshot-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/moonshot-chat-functions.adoc @@ -153,7 +153,7 @@ MoonshotChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), MoonshotChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function logger.info("Response: {}", response); @@ -190,7 +190,7 @@ var promptOptions = MoonshotChatOptions.builder() new MockWeatherService()))) // function code .build(); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/ollama-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/ollama-chat-functions.adoc index 652b69285..0e7ba4e90 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/ollama-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/ollama-chat-functions.adoc @@ -155,7 +155,7 @@ OllamaChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(userMessage, +ChatResponse response = this.chatModel.call(new Prompt(this.userMessage, OllamaOptions.builder().withFunction("CurrentWeather").build())); // Enable the function logger.info("Response: {}", response); @@ -191,7 +191,7 @@ var promptOptions = OllamaOptions.builder() new MockWeatherService()))) // function code .build(); -ChatResponse response = chatModel.call(new Prompt(userMessage, promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(this.userMessage, this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc index 851f019bc..1468380af 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc @@ -148,7 +148,7 @@ OpenAiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(userMessage, +ChatResponse response = this.chatModel.call(new Prompt(this.userMessage, OpenAiChatOptions.builder().withFunction("CurrentWeather").build())); // Enable the function logger.info("Response: {}", response); @@ -184,7 +184,7 @@ var promptOptions = OpenAiChatOptions.builder() new MockWeatherService()))) // function code .build(); -ChatResponse response = chatModel.call(new Prompt(userMessage, promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(this.userMessage, this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. @@ -254,7 +254,7 @@ BiFunction OpenAiChatOptions options = OpenAiChatOptions.builder() .withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()) - .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(weatherFunction) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(this.weatherFunction) .withName("getCurrentWeather") .withDescription("Get the weather in location") .build())) @@ -269,7 +269,7 @@ You can then use these options when making a call to the chat model: [source,java] ---- UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), options)); +ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage), options)); ---- This approach allows you to pass session-specific or user-specific information to your functions, enabling more contextual and personalized responses. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc index bf06d1257..b4aec52ca 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc @@ -156,7 +156,7 @@ VertexAiGeminiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), VertexAiGeminiChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function logger.info("Response: {}", response); @@ -194,7 +194,7 @@ var promptOptions = VertexAiGeminiChatOptions.builder() .build())) .build(); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/zhipuai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/zhipuai-chat-functions.adoc index c62d80616..30f425786 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/zhipuai-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/zhipuai-chat-functions.adoc @@ -153,7 +153,7 @@ ZhiPuAiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), ZhiPuAiChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function logger.info("Response: {}", response); @@ -190,7 +190,7 @@ var promptOptions = ZhiPuAiChatOptions.builder() new MockWeatherService()))) // function code .build(); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), this.promptOptions)); ---- NOTE: The in-prompt registered functions are enabled by default for the duration of this request. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc index cd622ef9d..0550b4fec 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc @@ -261,13 +261,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -307,14 +307,14 @@ var openAiChatOptions = OpenAiChatOptions.builder() .withTemperature(0.4) .withMaxTokens(200) .build(); -var chatModel = new OpenAiChatModel(openAiApi, openAiChatOptions); +var chatModel = new OpenAiChatModel(this.openAiApi, this.openAiChatOptions); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/huggingface.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/huggingface.adoc index 459e5eb5c..3cd180140 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/huggingface.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/huggingface.adoc @@ -95,7 +95,7 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } } ---- @@ -131,7 +131,7 @@ Next, create a `HuggingfaceChatModel` and use it for text generations: ---- HuggingfaceChatModel chatModel = new HuggingfaceChatModel(apiKey, url); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); System.out.println(response.getGeneration().getResult().getOutput().getContent()); diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/minimax-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/minimax-chat.adoc index 160b63ede..2cacdd0f9 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/minimax-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/minimax-chat.adoc @@ -161,13 +161,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { var prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -203,17 +203,17 @@ Next, create a `MiniMaxChatModel` and use it for text generations: ---- var miniMaxApi = new MiniMaxApi(System.getenv("MINIMAX_API_KEY")); -var chatModel = new MiniMaxChatModel(miniMaxApi, MiniMaxChatOptions.builder() +var chatModel = new MiniMaxChatModel(this.miniMaxApi, MiniMaxChatOptions.builder() .withModel(MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.getValue()) .withTemperature(0.4) .withMaxTokens(200) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux streamResponse = chatModel.stream( +Flux streamResponse = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -235,12 +235,12 @@ ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); // Sync request -ResponseEntity response = miniMaxApi.chatCompletionEntity( - new ChatCompletionRequest(List.of(chatCompletionMessage), MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.getValue(), 0.7f, false)); +ResponseEntity response = this.miniMaxApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.getValue(), 0.7f, false)); // Streaming request -Flux streamResponse = miniMaxApi.chatCompletionStream( - new ChatCompletionRequest(List.of(chatCompletionMessage), MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.getValue(), 0.7f, true)); +Flux streamResponse = this.miniMaxApi.chatCompletionStream( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.getValue(), 0.7f, true)); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java[MiniMaxApi.java]'s JavaDoc for further information. @@ -259,21 +259,21 @@ Here is a simple snippet how to use the web search: UserMessage userMessage = new UserMessage( "How many gold medals has the United States won in total at the 2024 Olympics?"); -List messages = new ArrayList<>(List.of(userMessage)); +List messages = new ArrayList<>(List.of(this.userMessage)); List functionTool = List.of(MiniMaxApi.FunctionTool.webSearchFunctionTool()); MiniMaxChatOptions options = MiniMaxChatOptions.builder() .withModel(MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.value) - .withTools(functionTool) + .withTools(this.functionTool) .build(); // Sync request -ChatResponse response = chatModel.call(new Prompt(messages, options)); +ChatResponse response = chatModel.call(new Prompt(this.messages, this.options)); // Streaming request -Flux streamResponse = chatModel.stream(new Prompt(messages, options)); +Flux streamResponse = chatModel.stream(new Prompt(this.messages, this.options)); ---- ==== MiniMaxApi Samples diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc index e0dd67906..6496f094f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc @@ -179,13 +179,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { var prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -221,17 +221,17 @@ Next, create a `MistralAiChatModel` and use it for text generations: ---- var mistralAiApi = new MistralAiApi(System.getenv("MISTRAL_AI_API_KEY")); -var chatModel = new MistralAiChatModel(mistralAiApi, MistralAiChatOptions.builder() +var chatModel = new MistralAiChatModel(this.mistralAiApi, MistralAiChatOptions.builder() .withModel(MistralAiApi.ChatModel.LARGE.getValue()) .withTemperature(0.4) .withMaxTokens(200) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -252,12 +252,12 @@ ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); // Sync request -ResponseEntity response = mistralAiApi.chatCompletionEntity( - new ChatCompletionRequest(List.of(chatCompletionMessage), MistralAiApi.ChatModel.LARGE.getValue(), 0.8, false)); +ResponseEntity response = this.mistralAiApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), MistralAiApi.ChatModel.LARGE.getValue(), 0.8, false)); // Streaming request -Flux streamResponse = mistralAiApi.chatCompletionStream( - new ChatCompletionRequest(List.of(chatCompletionMessage), MistralAiApi.ChatModel.LARGE.getValue(), 0.8, true)); +Flux streamResponse = this.mistralAiApi.chatCompletionStream( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), MistralAiApi.ChatModel.LARGE.getValue(), 0.8, true)); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java[MistralAiApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/moonshot-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/moonshot-chat.adoc index 36401bd13..a08478f38 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/moonshot-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/moonshot-chat.adoc @@ -160,13 +160,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { var prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -202,17 +202,17 @@ Next, create a `MoonshotChatModel` and use it for text generations: ---- var moonshotApi = new MoonshotApi(System.getenv("MOONSHOT_API_KEY")); -var chatModel = new MoonshotChatModel(moonshotApi, MoonshotChatOptions.builder() +var chatModel = new MoonshotChatModel(this.moonshotApi, MoonshotChatOptions.builder() .withModel(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue()) .withTemperature(0.4) .withMaxTokens(200) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux streamResponse = chatModel.stream( +Flux streamResponse = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -234,12 +234,12 @@ ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); // Sync request -ResponseEntity response = moonshotApi.chatCompletionEntity( - new ChatCompletionRequest(List.of(chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.7, false)); +ResponseEntity response = this.moonshotApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.7, false)); // Streaming request -Flux streamResponse = moonshotApi.chatCompletionStream( - new ChatCompletionRequest(List.of(chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.7, true)); +Flux streamResponse = this.moonshotApi.chatCompletionStream( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.7, true)); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java[MoonshotApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc index 406422c25..17bc1204a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc @@ -240,13 +240,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc index 71e30665b..8e915a0ae 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc @@ -251,9 +251,9 @@ Below is a straightforward code example excerpted from link:https://github.com/s var imageResource = new ClassPathResource("/multimodal.test.png"); var userMessage = new UserMessage("Explain what do you see on this picture?", - new Media(MimeTypeUtils.IMAGE_PNG, imageResource)); + new Media(MimeTypeUtils.IMAGE_PNG, this.imageResource)); -ChatResponse response = chatModel.call(new Prompt(userMessage, +ChatResponse response = chatModel.call(new Prompt(this.userMessage, OllamaOptions.builder().withModel(OllamaModel.LLAVA)).build()); ---- @@ -317,13 +317,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } @@ -369,16 +369,16 @@ Next, create an `OllamaChatModel` instance and use it to send requests for text ---- var ollamaApi = new OllamaApi(); -var chatModel = new OllamaChatModel(ollamaApi, +var chatModel = new OllamaChatModel(this.ollamaApi, OllamaOptions.create() .withModel(OllamaOptions.DEFAULT_MODEL) .withTemperature(0.9)); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -414,7 +414,7 @@ var request = ChatRequest.builder("orca-mini") .withOptions(OllamaOptions.create().withTemperature(0.9)) .build(); -ChatResponse response = ollamaApi.chat(request); +ChatResponse response = this.ollamaApi.chat(this.request); // Streaming request var request2 = ChatRequest.builder("orca-mini") @@ -425,5 +425,5 @@ var request2 = ChatRequest.builder("orca-mini") .withOptions(OllamaOptions.create().withTemperature(0.9).toMap()) .build(); -Flux streamingResponse = ollamaApi.streamingChat(request2); +Flux streamingResponse = this.ollamaApi.streamingChat(this.request2); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc index 88e6a4b6d..8ee2e8d4f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc @@ -178,9 +178,9 @@ Below is a code example excerpted from link:https://github.com/spring-projects/s var imageResource = new ClassPathResource("/multimodal.test.png"); var userMessage = new UserMessage("Explain what do you see on this picture?", - new Media(MimeTypeUtils.IMAGE_PNG, imageResource)); + new Media(MimeTypeUtils.IMAGE_PNG, this.imageResource)); -ChatResponse response = chatModel.call(new Prompt(userMessage, +ChatResponse response = chatModel.call(new Prompt(this.userMessage, OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()).build())); ---- @@ -194,7 +194,7 @@ var userMessage = new UserMessage("Explain what do you see on this picture?", new Media(MimeTypeUtils.IMAGE_PNG, "https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")); -ChatResponse response = chatModel.call(new Prompt(userMessage, +ChatResponse response = chatModel.call(new Prompt(this.userMessage, OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()).build())); ---- @@ -258,10 +258,10 @@ String jsonSchema = """ Prompt prompt = new Prompt("how can I solve 8x + 7 = -23", OpenAiChatOptions.builder() .withModel(ChatModel.GPT_4_O_MINI) - .withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, jsonSchema)) + .withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, this.jsonSchema)) .build()); -ChatResponse response = this.openAiChatModel.call(prompt); +ChatResponse response = this.openAiChatModel.call(this.prompt); ---- NOTE: Adhere to the OpenAI link:https://platform.openai.com/docs/guides/structured-outputs/supported-schemas[subset of the JSON Schema language] format. @@ -288,18 +288,18 @@ record MathReasoning( var outputConverter = new BeanOutputConverter<>(MathReasoning.class); -var jsonSchema = outputConverter.getJsonSchema(); +var jsonSchema = this.outputConverter.getJsonSchema(); Prompt prompt = new Prompt("how can I solve 8x + 7 = -23", OpenAiChatOptions.builder() .withModel(ChatModel.GPT_4_O_MINI) - .withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, jsonSchema)) + .withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, this.jsonSchema)) .build()); -ChatResponse response = this.openAiChatModel.call(prompt); -String content = response.getResult().getOutput().getContent(); +ChatResponse response = this.openAiChatModel.call(this.prompt); +String content = this.response.getResult().getOutput().getContent(); -MathReasoning mathReasoning = outputConverter.convert(content); +MathReasoning mathReasoning = this.outputConverter.convert(this.content); ---- NOTE: Ensure you use the `@JsonProperty(required = true,...)` annotation. @@ -353,13 +353,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -399,13 +399,13 @@ var openAiChatOptions = OpenAiChatOptions.builder() .withTemperature(0.4) .withMaxTokens(200) .build(); -var chatModel = new OpenAiChatModel(openAiApi, openAiChatOptions); +var chatModel = new OpenAiChatModel(this.openAiApi, this.openAiChatOptions); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux response = chatModel.stream( +Flux response = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -431,12 +431,12 @@ ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); // Sync request -ResponseEntity response = openAiApi.chatCompletionEntity( - new ChatCompletionRequest(List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false)); +ResponseEntity response = this.openAiApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), "gpt-3.5-turbo", 0.8, false)); // Streaming request -Flux streamResponse = openAiApi.chatCompletionStream( - new ChatCompletionRequest(List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true)); +Flux streamResponse = this.openAiApi.chatCompletionStream( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), "gpt-3.5-turbo", 0.8, true)); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java[OpenAiApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/qianfan-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/qianfan-chat.adoc index 43bc86c48..19994a4f4 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/qianfan-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/qianfan-chat.adoc @@ -164,13 +164,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatClient.call(message)); + return Map.of("generation", this.chatClient.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { var prompt = new Prompt(new UserMessage(message)); - return chatClient.stream(prompt); + return this.chatClient.stream(prompt); } } ---- @@ -206,17 +206,17 @@ Next, create a `QianFanChatModel` and use it for text generations: ---- var qianFanApi = new QianFanApi(System.getenv("QIANFAN_API_KEY"), System.getenv("QIANFAN_SECRET_KEY")); -var chatClient = new QianFanChatModel(qianFanApi, QianFanChatOptions.builder() +var chatClient = new QianFanChatModel(this.qianFanApi, QianFanChatOptions.builder() .withModel(QianFanApi.ChatModel.ERNIE_Speed_8K.getValue()) .withTemperature(0.4) .withMaxTokens(200) .build()); -ChatResponse response = chatClient.call( +ChatResponse response = this.chatClient.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux streamResponse = chatClient.stream( +Flux streamResponse = this.chatClient.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -240,12 +240,12 @@ ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); // Sync request -ResponseEntity response = qianFanApi.chatCompletionEntity( - new ChatCompletionRequest(List.of(chatCompletionMessage), systemMessage, QianFanApi.ChatModel.ERNIE_Speed_8K.getValue(), 0.7, false)); +ResponseEntity response = this.qianFanApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), this.systemMessage, QianFanApi.ChatModel.ERNIE_Speed_8K.getValue(), 0.7, false)); // Streaming request -Flux streamResponse = qianFanApi.chatCompletionStream( - new ChatCompletionRequest(List.of(chatCompletionMessage), systemMessage, QianFanApi.ChatModel.ERNIE_Speed_8K.getValue(), 0.7, true)); +Flux streamResponse = this.qianFanApi.chatCompletionStream( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), this.systemMessage, QianFanApi.ChatModel.ERNIE_Speed_8K.getValue(), 0.7, true)); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java[QianFanApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc index 9e6c238b3..58ac42801 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc @@ -139,9 +139,9 @@ Below is a simple code example extracted from https://github.com/spring-projects byte[] data = new ClassPathResource("/vertex-test.png").getContentAsByteArray(); var userMessage = new UserMessage("Explain what do you see on this picture?", - List.of(new Media(MimeTypeUtils.IMAGE_PNG, data))); + List.of(new Media(MimeTypeUtils.IMAGE_PNG, this.data))); -ChatResponse response = chatModel.call(new Prompt(List.of(userMessage))); +ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage))); ---- == Sample Controller @@ -177,13 +177,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -219,13 +219,13 @@ Next, create a `VertexAiGeminiChatModel` and use it for text generations: ---- VertexAI vertexApi = new VertexAI(projectId, location); -var chatModel = new VertexAiGeminiChatModel(vertexApi, +var chatModel = new VertexAiGeminiChatModel(this.vertexApi, VertexAiGeminiChatOptions.builder() .withModel(ChatModel.GEMINI_PRO_1_5_PRO) .withTemperature(0.4) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-palm2-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-palm2-chat.adoc index 7d9bcbb0a..017f4be42 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-palm2-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-palm2-chat.adoc @@ -136,13 +136,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -178,12 +178,12 @@ Next, create a `VertexAiPaLm2ChatModel` and use it for text generations: ---- VertexAiPaLm2Api vertexAiApi = new VertexAiPaLm2Api(< YOUR PALM_API_KEY>); -var chatModel = new VertexAiPaLm2ChatModel(vertexAiApi, +var chatModel = new VertexAiPaLm2ChatModel(this.vertexAiApi, VertexAiPaLm2ChatOptions.builder() .withTemperature(0.4) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -207,15 +207,15 @@ VertexAiPaLm2Api vertexAiApi = new VertexAiPaLm2Api(< YOUR PALM_API_KEY>); // Generate var prompt = new MessagePrompt(List.of(new Message("0", "Hello, how are you?"))); -GenerateMessageRequest request = new GenerateMessageRequest(prompt); +GenerateMessageRequest request = new GenerateMessageRequest(this.prompt); -GenerateMessageResponse response = vertexAiApi.generateMessage(request); +GenerateMessageResponse response = this.vertexAiApi.generateMessage(this.request); // Embed text -Embedding embedding = vertexAiApi.embedText("Hello, how are you?"); +Embedding embedding = this.vertexAiApi.embedText("Hello, how are you?"); // Batch embedding -List embeddings = vertexAiApi.batchEmbedText(List.of("Hello, how are you?", "I am fine, thank you!")); +List embeddings = this.vertexAiApi.batchEmbedText(List.of("Hello, how are you?", "I am fine, thank you!")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/watsonx-ai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/watsonx-ai-chat.adoc index 9c63162d9..5f0727dcf 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/watsonx-ai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/watsonx-ai-chat.adoc @@ -119,7 +119,7 @@ public class MyClass { Prompt prompt = new Prompt(new SystemMessage(userInput), options); - var results = chatModel.call(prompt); + var results = this.chatModel.call(prompt); var generatedText = results.getResult().getOutput().getContent(); @@ -135,7 +135,7 @@ public class MyClass { Prompt prompt = new Prompt(new SystemMessage(userInput), options); - var results = chatModel.stream(prompt).collectList().block(); // wait till the stream is resolved (completed) + var results = this.chatModel.stream(prompt).collectList().block(); // wait till the stream is resolved (completed) var generatedText = results.stream() .map(generation -> generation.getResult().getOutput().getContent()) diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc index 7f417336d..9a178bbe6 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc @@ -162,13 +162,13 @@ public class ChatController { @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } @GetMapping("/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { var prompt = new Prompt(new UserMessage(message)); - return chatModel.stream(prompt); + return this.chatModel.stream(prompt); } } ---- @@ -204,17 +204,17 @@ Next, create a `ZhiPuAiChatModel` and use it for text generations: ---- var zhiPuAiApi = new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY")); -var chatModel = new ZhiPuAiChatModel(zhiPuAiApi, ZhiPuAiChatOptions.builder() +var chatModel = new ZhiPuAiChatModel(this.zhiPuAiApi, ZhiPuAiChatOptions.builder() .withModel(ZhiPuAiApi.ChatModel.GLM_3_Turbo.getValue()) .withTemperature(0.4) .withMaxTokens(200) .build()); -ChatResponse response = chatModel.call( +ChatResponse response = this.chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); // Or with streaming responses -Flux streamResponse = chatModel.stream( +Flux streamResponse = this.chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- @@ -236,12 +236,12 @@ ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); // Sync request -ResponseEntity response = zhiPuAiApi.chatCompletionEntity( - new ChatCompletionRequest(List.of(chatCompletionMessage), ZhiPuAiApi.ChatModel.GLM_3_Turbo.getValue(), 0.7, false)); +ResponseEntity response = this.zhiPuAiApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), ZhiPuAiApi.ChatModel.GLM_3_Turbo.getValue(), 0.7, false)); // Streaming request -Flux streamResponse = zhiPuAiApi.chatCompletionStream( - new ChatCompletionRequest(List.of(chatCompletionMessage), ZhiPuAiApi.ChatModel.GLM_3_Turbo.getValue(), 0.7, true)); +Flux streamResponse = this.zhiPuAiApi.chatCompletionStream( + new ChatCompletionRequest(List.of(this.chatCompletionMessage), ZhiPuAiApi.ChatModel.GLM_3_Turbo.getValue(), 0.7, true)); ---- Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java[ZhiPuAiApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc index b623e01ba..2ba0f6f62 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc @@ -57,11 +57,11 @@ Then, create a `ChatClient.Builder` instance programmatically for every `ChatMod ---- ChatModel myChatModel = ... // usually autowired -ChatClient.Builder builder = ChatClient.builder(myChatModel); +ChatClient.Builder builder = ChatClient.builder(this.myChatModel); // or create a ChatClient with the default builder settings: -ChatClient chatClient = ChatClient.create(myChatModel); +ChatClient chatClient = ChatClient.create(this.myChatModel); ---- == ChatClient Fluent API @@ -156,13 +156,13 @@ Flux flux = this.chatClient.prompt() Generate the filmography for a random actor. {format} """) - .param("format", converter.getFormat())) + .param("format", this.converter.getFormat())) .stream() .content(); -String content = flux.collectList().block().stream().collect(Collectors.joining()); +String content = this.flux.collectList().block().stream().collect(Collectors.joining()); -List actorFilms = converter.convert(content); +List actorFilms = this.converter.convert(this.content); ---- == call() return values @@ -224,7 +224,7 @@ class AIController { @GetMapping("/ai/simple") public Map completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("completion", chatClient.prompt().user(message).call().content()); + return Map.of("completion", this.chatClient.prompt().user(message).call().content()); } } ---- @@ -268,7 +268,7 @@ class AIController { @GetMapping("/ai") Map completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message, String voice) { return Map.of("completion", - chatClient.prompt() + this.chatClient.prompt() .system(sp -> sp.param("voice", voice)) .user(message) .call() @@ -397,7 +397,7 @@ ChatClient chatClient = ChatClient.builder(chatModel) .build(); // Update filter expression at runtime -String content = chatClient.prompt() +String content = this.chatClient.prompt() .user("Please answer my question XYZ") .advisors(a -> a.param(QuestionAnswerAdvisor.FILTER_EXPRESSION, "type == 'Spring'")) .call() diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc index f4620cd5c..f70999a1f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc @@ -197,13 +197,13 @@ var openAIClient = OpenAIClientBuilder() .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .buildClient(); -var embeddingModel = new AzureOpenAiEmbeddingModel(openAIClient) +var embeddingModel = new AzureOpenAiEmbeddingModel(this.openAIClient) .withDefaultOptions(AzureOpenAiEmbeddingOptions.builder() .withModel("text-embedding-ada-002") .withUser("user-6") .build()); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-cohere-embedding.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-cohere-embedding.adoc index de0419792..0786bace4 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-cohere-embedding.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-cohere-embedding.adoc @@ -172,9 +172,9 @@ var cohereEmbeddingApi =new CohereEmbeddingBedrockApi( EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper()); -var embeddingModel = new BedrockCohereEmbeddingModel(cohereEmbeddingApi); +var embeddingModel = new BedrockCohereEmbeddingModel(this.cohereEmbeddingApi); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- @@ -202,7 +202,7 @@ CohereEmbeddingRequest request = new CohereEmbeddingRequest( CohereEmbeddingRequest.InputType.search_document, CohereEmbeddingRequest.Truncate.NONE); -CohereEmbeddingResponse response = api.embedding(request); +CohereEmbeddingResponse response = this.api.embedding(this.request); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc index e694649e1..2e0e93adf 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc @@ -170,9 +170,9 @@ Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/sp var titanEmbeddingApi = new TitanEmbeddingBedrockApi( TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id(), Region.US_EAST_1.id()); -var embeddingModel = new BedrockTitanEmbeddingModel(titanEmbeddingApi); +var embeddingModel = new BedrockTitanEmbeddingModel(this.titanEmbeddingApi); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World")); // NOTE titan does not support batch embedding. ---- @@ -197,7 +197,7 @@ TitanEmbeddingRequest request = TitanEmbeddingRequest.builder() .withInputText("I like to eat apples.") .build(); -TitanEmbeddingResponse response = titanEmbedApi.embedding(request); +TitanEmbeddingResponse response = this.titanEmbedApi.embedding(this.request); ---- To embed an image you need to convert it into `base64` format: @@ -213,8 +213,8 @@ byte[] image = new DefaultResourceLoader() TitanEmbeddingRequest request = TitanEmbeddingRequest.builder() - .withInputImage(Base64.getEncoder().encodeToString(image)) + .withInputImage(Base64.getEncoder().encodeToString(this.image)) .build(); -TitanEmbeddingResponse response = titanEmbedApi.embedding(request); +TitanEmbeddingResponse response = this.titanEmbedApi.embedding(this.request); ---- \ No newline at end of file diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/minimax-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/minimax-embeddings.adoc index b0a917098..114cc6244 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/minimax-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/minimax-embeddings.adoc @@ -183,12 +183,12 @@ Next, create an `MiniMaxEmbeddingModel` instance and use it to compute the simil ---- var miniMaxApi = new MiniMaxApi(System.getenv("MINIMAX_API_KEY")); -var embeddingModel = new MiniMaxEmbeddingModel(miniMaxApi) +var embeddingModel = new MiniMaxEmbeddingModel(this.miniMaxApi) .withDefaultOptions(MiniMaxChatOptions.build() .withModel("embo-01") .build()); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/mistralai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/mistralai-embeddings.adoc index 23f2781ba..46e9d4953 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/mistralai-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/mistralai-embeddings.adoc @@ -184,13 +184,13 @@ Next, create an `MistralAiEmbeddingModel` instance and use it to compute the sim ---- var mistralAiApi = new MistralAiApi(System.getenv("MISTRAL_AI_API_KEY")); -var embeddingModel = new MistralAiEmbeddingModel(mistralAiApi, +var embeddingModel = new MistralAiEmbeddingModel(this.mistralAiApi, MistralAiEmbeddingOptions.builder() .withModel("mistral-embed") .withEncodingFormat("float") .build()); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/oci-genai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/oci-genai-embeddings.adoc index b8d42eee2..d9409675a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/oci-genai-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/oci-genai-embeddings.adoc @@ -159,15 +159,15 @@ final String REGION = "us-chicago-1"; final String COMPARTMENT_ID = System.getenv("OCI_COMPARTMENT_ID"); var authProvider = new ConfigFileAuthenticationDetailsProvider( - CONFIG_FILE, PROFILE); + this.CONFIG_FILE, this.PROFILE); var aiClient = GenerativeAiInferenceClient.builder() - .region(Region.valueOf(REGION)) - .build(authProvider); + .region(Region.valueOf(this.REGION)) + .build(this.authProvider); var options = OCIEmbeddingOptions.builder() - .withModel(EMBEDDING_MODEL) - .withCompartment(COMPARTMENT_ID) + .withModel(this.EMBEDDING_MODEL) + .withCompartment(this.COMPARTMENT_ID) .withServingMode("on-demand") .build(); -var embeddingModel = new OCIEmbeddingModel(aiClient, options); -List embedding = embeddingModel.embed(new Document("How many provinces are in Canada?")); +var embeddingModel = new OCIEmbeddingModel(this.aiClient, this.options); +List embedding = this.embeddingModel.embed(new Document("How many provinces are in Canada?")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc index 726367892..b522a1694 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc @@ -285,12 +285,12 @@ Next, create an `OllamaEmbeddingModel` instance and use it to compute the embedd ---- var ollamaApi = new OllamaApi(); -var embeddingModel = new OllamaEmbeddingModel(ollamaApi, +var embeddingModel = new OllamaEmbeddingModel(this.ollamaApi, OllamaOptions.builder() .withModel(OllamaModel.MISTRAL.id()) .build()); -EmbeddingResponse embeddingResponse = embeddingModel.call( +EmbeddingResponse embeddingResponse = this.embeddingModel.call( new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"), OllamaOptions.builder() .withModel("chroma/all-minilm-l6-v2-f32")) diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc index 0f9fe553a..51a53bb43 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc @@ -176,7 +176,7 @@ embeddingModel.setTokenizerOptions(Map.of("padding", "true")); embeddingModel.afterPropertiesSet(); -List> embeddings = embeddingModel.embed(List.of("Hello world", "World is big")); +List> embeddings = this.embeddingModel.embed(List.of("Hello world", "World is big")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc index 33a93c9a9..c1669acee 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc @@ -196,7 +196,7 @@ Next, create an `OpenAiEmbeddingModel` instance and use it to compute the simila var openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY")); var embeddingModel = new OpenAiEmbeddingModel( - openAiApi, + this.openAiApi, MetadataMode.EMBED, OpenAiEmbeddingOptions.builder() .withModel("text-embedding-ada-002") @@ -204,7 +204,7 @@ var embeddingModel = new OpenAiEmbeddingModel( .build(), RetryUtils.DEFAULT_RETRY_TEMPLATE); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/postgresml-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/postgresml-embeddings.adoc index f3c04e0f2..7fc86f0ab 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/postgresml-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/postgresml-embeddings.adoc @@ -155,7 +155,7 @@ PostgresMlEmbeddingModel embeddingModel = new PostgresMlEmbeddingModel(this.jdbc embeddingModel.afterPropertiesSet(); // initialize the jdbc template and database. -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/qianfan-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/qianfan-embeddings.adoc index 260ad6363..31c7c9826 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/qianfan-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/qianfan-embeddings.adoc @@ -192,7 +192,7 @@ var embeddingClient = new QianFanEmbeddingModel(qianFanApi) .withModel("bge_large_en") .build()); -EmbeddingResponse embeddingResponse = embeddingClient +EmbeddingResponse embeddingResponse = this.embeddingClient .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-multimodal.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-multimodal.adoc index 526d07c39..1f9db4495 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-multimodal.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-multimodal.adoc @@ -127,20 +127,20 @@ VertexAiMultimodalEmbeddingOptions options = VertexAiMultimodalEmbeddingOptions. .withModel(VertexAiMultimodalEmbeddingOptions.DEFAULT_MODEL_NAME) .build(); -var embeddingModel = new VertexAiMultimodalEmbeddingModel(connectionDetails, options); +var embeddingModel = new VertexAiMultimodalEmbeddingModel(this.connectionDetails, this.options); Media imageMedial = new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png")); Media videoMedial = new Media(new MimeType("video", "mp4"), new ClassPathResource("/test.video.mp4")); -var document = new Document("Explain what do you see on this video?", List.of(imageMedial, videoMedial), Map.of()); +var document = new Document("Explain what do you see on this video?", List.of(this.imageMedial, this.videoMedial), Map.of()); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); -DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(List.of(document), +DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(List.of(this.document), EmbeddingOptions.EMPTY); -EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); +EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(this.embeddingRequest); assertThat(embeddingResponse.getResults()).hasSize(3); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-palm2.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-palm2.adoc index b81e56d92..63365196f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-palm2.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-palm2.adoc @@ -146,9 +146,9 @@ Next, create a `VertexAiPaLm2EmbeddingModel` and use it for text generations: ---- VertexAiPaLm2Api vertexAiApi = new VertexAiPaLm2Api(< YOUR PALM_API_KEY>); -var embeddingModel = new VertexAiPaLm2EmbeddingModel(vertexAiApi); +var embeddingModel = new VertexAiPaLm2EmbeddingModel(this.vertexAiApi); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- @@ -169,15 +169,15 @@ VertexAiPaLm2Api vertexAiApi = new VertexAiPaLm2Api(< YOUR PALM_API_KEY>); // Generate var prompt = new MessagePrompt(List.of(new Message("0", "Hello, how are you?"))); -GenerateMessageRequest request = new GenerateMessageRequest(prompt); +GenerateMessageRequest request = new GenerateMessageRequest(this.prompt); -GenerateMessageResponse response = vertexAiApi.generateMessage(request); +GenerateMessageResponse response = this.vertexAiApi.generateMessage(this.request); // Embed text -Embedding embedding = vertexAiApi.embedText("Hello, how are you?"); +Embedding embedding = this.vertexAiApi.embedText("Hello, how are you?"); // Batch embedding -List embeddings = vertexAiApi.batchEmbedText(List.of("Hello, how are you?", "I am fine, thank you!")); +List embeddings = this.vertexAiApi.batchEmbedText(List.of("Hello, how are you?", "I am fine, thank you!")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-text.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-text.adoc index 56bdef6d5..839319c52 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-text.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-text.adoc @@ -154,9 +154,9 @@ VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder() .withModel(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME) .build(); -var embeddingModel = new VertexAiTextEmbeddingModel(connectionDetails, options); +var embeddingModel = new VertexAiTextEmbeddingModel(this.connectionDetails, this.options); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/zhipuai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/zhipuai-embeddings.adoc index 444dede39..0fd6a08d1 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/zhipuai-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/zhipuai-embeddings.adoc @@ -183,12 +183,12 @@ Next, create an `ZhiPuAiEmbeddingModel` instance and use it to compute the simil ---- var zhiPuAiApi = new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY")); -var embeddingModel = new ZhiPuAiEmbeddingModel(zhiPuAiApi) +var embeddingModel = new ZhiPuAiEmbeddingModel(this.zhiPuAiApi) .withDefaultOptions(ZhiPuAiChatOptions.build() .withModel("embedding-2") .build()); -EmbeddingResponse embeddingResponse = embeddingModel +EmbeddingResponse embeddingResponse = this.embeddingModel .embedForResponse(List.of("Hello World", "World is big and salvation is near")); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc index f8daf7b09..06c5808d7 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc @@ -123,7 +123,7 @@ class MyJsonReader { } List loadJsonAsDocuments() { - JsonReader jsonReader = new JsonReader(resource, "description", "content"); + JsonReader jsonReader = new JsonReader(this.resource, "description", "content"); return jsonReader.get(); } } @@ -189,7 +189,7 @@ This method allows you to use a JSON Pointer to retrieve a specific part of the [source,java] ---- JsonReader jsonReader = new JsonReader(resource, "description"); -List documents = jsonReader.get("/store/books/0"); +List documents = this.jsonReader.get("/store/books/0"); ---- ==== Example JSON Structure @@ -236,7 +236,7 @@ class MyTextReader { this.resource = resource; } List loadText() { - TextReader textReader = new TextReader(resource); + TextReader textReader = new TextReader(this.resource); textReader.getCustomMetadata().put("filename", "text-source.txt"); return textReader.read(); @@ -281,7 +281,7 @@ The `TextReader` processes text content as follows: [source,java] ---- List documents = textReader.get(); -List splitDocuments = new TokenTextSplitter().apply(documents); +List splitDocuments = new TokenTextSplitter().apply(this.documents); ---- * The reader uses Spring's `Resource` abstraction, allowing it to read from various sources (classpath, file system, URL, etc.). @@ -313,7 +313,7 @@ class MyMarkdownReader { .withAdditionalMetadata("filename", "code.md") .build(); - MarkdownDocumentReader reader = new MarkdownDocumentReader(resource, config); + MarkdownDocumentReader reader = new MarkdownDocumentReader(this.resource, config); return reader.get(); } } @@ -501,7 +501,7 @@ class MyTikaDocumentReader { } List loadText() { - TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(resource); + TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(this.resource); return tikaDocumentReader.read(); } } @@ -576,7 +576,7 @@ Document doc2 = new Document("Another document with content that will be split b Map.of("source", "example2.txt")); TokenTextSplitter splitter = new TokenTextSplitter(); -List splitDocuments = splitter.apply(List.of(doc1, doc2)); +List splitDocuments = this.splitter.apply(List.of(this.doc1, this.doc2)); for (Document doc : splitDocuments) { System.out.println("Chunk: " + doc.getContent()); @@ -612,7 +612,7 @@ class MyKeywordEnricher { } List enrichDocuments(List documents) { - KeywordMetadataEnricher enricher = new KeywordMetadataEnricher(chatModel, 5); + KeywordMetadataEnricher enricher = new KeywordMetadataEnricher(this.chatModel, 5); return enricher.apply(documents); } } @@ -655,10 +655,10 @@ KeywordMetadataEnricher enricher = new KeywordMetadataEnricher(chatModel, 5); Document doc = new Document("This is a document about artificial intelligence and its applications in modern technology."); -List enrichedDocs = enricher.apply(List.of(doc)); +List enrichedDocs = enricher.apply(List.of(this.doc)); -Document enrichedDoc = enrichedDocs.get(0); -String keywords = (String) enrichedDoc.getMetadata().get("excerpt_keywords"); +Document enrichedDoc = this.enrichedDocs.get(0); +String keywords = (String) this.enrichedDoc.getMetadata().get("excerpt_keywords"); System.out.println("Extracted keywords: " + keywords); ---- @@ -697,7 +697,7 @@ class MySummaryEnricher { } List enrichDocuments(List documents) { - return enricher.apply(documents); + return this.enricher.apply(documents); } } ---- @@ -757,7 +757,7 @@ SummaryMetadataEnricher enricher = new SummaryMetadataEnricher(chatModel, Document doc1 = new Document("Content of document 1"); Document doc2 = new Document("Content of document 2"); -List enrichedDocs = enricher.apply(List.of(doc1, doc2)); +List enrichedDocs = enricher.apply(List.of(this.doc1, this.doc2)); // Check the metadata of the enriched documents for (Document doc : enrichedDocs) { diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc index aad707611..db7d51693 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc @@ -154,7 +154,7 @@ To let the model know and call your `CurrentWeather` function you need to enable ---- ChatClient chatClient = ... -ChatResponse response = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") +ChatResponse response = this.chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") .functions("CurrentWeather") // Enable the function .call(). chatResponse(); @@ -181,7 +181,7 @@ In addition to the auto-configuration, you can register callback functions, dyna ---- ChatClient chatClient = ... -ChatResponse response = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") +ChatResponse response = this.chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") .functions(new FunctionCallbackWrapper<>( "CurrentWeather", // name "Get the weather in location", // function description @@ -230,7 +230,7 @@ BiFunction ChatResponse response = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions(FunctionCallbackWrapper.builder(weatherFunction) + .functions(FunctionCallbackWrapper.builder(this.weatherFunction) .withName("getCurrentWeather") .withDescription("Get the weather in location") .build()) diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/moderation/openai-moderation.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/moderation/openai-moderation.adoc index 82fefe854..81c1c9120 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/moderation/openai-moderation.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/moderation/openai-moderation.adoc @@ -81,8 +81,8 @@ OpenAiModerationOptions moderationOptions = OpenAiModerationOptions.builder() .withModel("text-moderation-latest") .build(); -ModerationPrompt moderationPrompt = new ModerationPrompt("Text to be moderated", moderationOptions); -ModerationResponse response = openAiModerationModel.call(moderationPrompt); +ModerationPrompt moderationPrompt = new ModerationPrompt("Text to be moderated", this.moderationOptions); +ModerationResponse response = openAiModerationModel.call(this.moderationPrompt); // Access the moderation results Moderation moderation = moderationResponse.getResult().getOutput(); @@ -97,7 +97,7 @@ for (ModerationResult result : moderation.getResults()) { System.out.println("Flagged: " + result.isFlagged()); // Access categories - Categories categories = result.getCategories(); + Categories categories = this.result.getCategories(); System.out.println("\nCategories:"); System.out.println("Sexual: " + categories.isSexual()); System.out.println("Hate: " + categories.isHate()); @@ -112,7 +112,7 @@ for (ModerationResult result : moderation.getResults()) { System.out.println("Violence: " + categories.isViolence()); // Access category scores - CategoryScores scores = result.getCategoryScores(); + CategoryScores scores = this.result.getCategoryScores(); System.out.println("\nCategory Scores:"); System.out.println("Sexual: " + scores.getSexual()); System.out.println("Hate: " + scores.getHate()); @@ -158,14 +158,14 @@ Next, create an OpenAiModerationModel: ---- OpenAiModerationApi openAiModerationApi = new OpenAiModerationApi(System.getenv("OPENAI_API_KEY")); -OpenAiModerationModel openAiModerationModel = new OpenAiModerationModel(openAiModerationApi); +OpenAiModerationModel openAiModerationModel = new OpenAiModerationModel(this.openAiModerationApi); OpenAiModerationOptions moderationOptions = OpenAiModerationOptions.builder() .withModel("text-moderation-latest") .build(); -ModerationPrompt moderationPrompt = new ModerationPrompt("Text to be moderated", moderationOptions); -ModerationResponse response = openAiModerationModel.call(moderationPrompt); +ModerationPrompt moderationPrompt = new ModerationPrompt("Text to be moderated", this.moderationOptions); +ModerationResponse response = this.openAiModerationModel.call(this.moderationPrompt); ---- == Example Code diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc index bbcab25b3..8afc5586d 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc @@ -43,9 +43,9 @@ var imageResource = new ClassPathResource("/multimodal.test.png"); var userMessage = new UserMessage( "Explain what do you see in this picture?", // content - new Media(MimeTypeUtils.IMAGE_PNG, imageResource)); // media + new Media(MimeTypeUtils.IMAGE_PNG, this.imageResource)); // media -ChatResponse response = chatModel.call(new Prompt(userMessage)); +ChatResponse response = chatModel.call(new Prompt(this.userMessage)); ---- or with the fluent xref::api/chatclient.adoc[ChatClient] API: diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/prompt.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/prompt.adoc index 7e11009f5..6a41bacc5 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/prompt.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/prompt.adoc @@ -202,7 +202,7 @@ A simple example taken from the https://github.com/Azure-Samples/spring-ai-azure PromptTemplate promptTemplate = new PromptTemplate("Tell me a {adjective} joke about {topic}"); -Prompt prompt = promptTemplate.create(Map.of("adjective", adjective, "topic", topic)); +Prompt prompt = this.promptTemplate.create(Map.of("adjective", adjective, "topic", topic)); return chatModel.call(prompt).getResult(); ``` @@ -215,7 +215,7 @@ String userText = """ Write at least a sentence for each pirate. """; -Message userMessage = new UserMessage(userText); +Message userMessage = new UserMessage(this.userText); String systemText = """ You are a helpful AI assistant that helps people find information. @@ -223,12 +223,12 @@ String systemText = """ You should reply to the user's request with your name and also in the style of a {voice}. """; -SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemText); -Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); +SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemText); +Message systemMessage = this.systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); -Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); +Prompt prompt = new Prompt(List.of(this.userMessage, this.systemMessage)); -List response = chatModel.call(prompt).getResults(); +List response = chatModel.call(this.prompt).getResults(); ``` diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc index 90dd0c944..1fc320cd3 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc @@ -71,7 +71,7 @@ The format instructions are most often appended to the end of the user input usi """; // user input with a "format" placeholder. Prompt prompt = new Prompt( new PromptTemplate( - userInputTemplate, + this.userInputTemplate, Map.of(..., "format", outputConverter.getFormat()) // replace the "format" placeholder with the converter's format. ).createMessage()); ---- @@ -124,7 +124,7 @@ or using the low-level `ChatModel` API directly: BeanOutputConverter beanOutputConverter = new BeanOutputConverter<>(ActorsFilms.class); -String format = beanOutputConverter.getFormat(); +String format = this.beanOutputConverter.getFormat(); String actor = "Tom Hanks"; @@ -134,9 +134,9 @@ String template = """ """; Generation generation = chatModel.call( - new PromptTemplate(template, Map.of("actor", actor, "format", format)).create()).getResult(); + new PromptTemplate(this.template, Map.of("actor", this.actor, "format", this.format)).create()).getResult(); -ActorsFilms actorsFilms = beanOutputConverter.convert(generation.getOutput().getContent()); +ActorsFilms actorsFilms = this.beanOutputConverter.convert(this.generation.getOutput().getContent()); ---- ==== Generic Bean Types @@ -159,17 +159,17 @@ or using the low-level `ChatModel` API directly: BeanOutputConverter> outputConverter = new BeanOutputConverter<>( new ParameterizedTypeReference>() { }); -String format = outputConverter.getFormat(); +String format = this.outputConverter.getFormat(); String template = """ Generate the filmography of 5 movies for Tom Hanks and Bill Murray. {format} """; -Prompt prompt = new PromptTemplate(template, Map.of("format", format)).create(); +Prompt prompt = new PromptTemplate(this.template, Map.of("format", this.format)).create(); -Generation generation = chatModel.call(prompt).getResult(); +Generation generation = chatModel.call(this.prompt).getResult(); -List actorsFilms = outputConverter.convert(generation.getOutput().getContent()); +List actorsFilms = this.outputConverter.convert(this.generation.getOutput().getContent()); ---- === Map Output Converter @@ -191,18 +191,18 @@ or using the low-level `ChatModel` API directly: ---- MapOutputConverter mapOutputConverter = new MapOutputConverter(); -String format = mapOutputConverter.getFormat(); +String format = this.mapOutputConverter.getFormat(); String template = """ Provide me a List of {subject} {format} """; -Prompt prompt = new PromptTemplate(template, - Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)).create(); +Prompt prompt = new PromptTemplate(this.template, + Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", this.format)).create(); -Generation generation = chatModel.call(prompt).getResult(); +Generation generation = chatModel.call(this.prompt).getResult(); -Map result = mapOutputConverter.convert(generation.getOutput().getContent()); +Map result = this.mapOutputConverter.convert(this.generation.getOutput().getContent()); ---- === List Output Converter @@ -224,18 +224,18 @@ or using the low-level `ChatModel API` directly: ---- ListOutputConverter listOutputConverter = new ListOutputConverter(new DefaultConversionService()); -String format = listOutputConverter.getFormat(); +String format = this.listOutputConverter.getFormat(); String template = """ List five {subject} {format} """; -Prompt prompt = new PromptTemplate(template, - Map.of("subject", "ice cream flavors", "format", format)).create(); +Prompt prompt = new PromptTemplate(this.template, + Map.of("subject", "ice cream flavors", "format", this.format)).create(); -Generation generation = this.chatModel.call(prompt).getResult(); +Generation generation = this.chatModel.call(this.prompt).getResult(); -List list = listOutputConverter.convert(generation.getOutput().getContent()); +List list = this.listOutputConverter.convert(this.generation.getOutput().getContent()); ---- == Supported AI Models diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc index 387c18055..9f320f0d1 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc @@ -167,7 +167,7 @@ Additionally, `TokenCountBatchingStrategy` provides flexibility by allowing you ---- TokenCountEstimator customEstimator = new YourCustomTokenCountEstimator(); TokenCountBatchingStrategy strategy = new TokenCountBatchingStrategy( - customEstimator, + this.customEstimator, 8000, // maxInputTokenCount 0.1, // reservePercentage Document.DEFAULT_CONTENT_FORMATTER, @@ -256,7 +256,7 @@ Later, when a user question is passed into the AI model, a similarity search is ```java String question = - List similarDocuments = store.similaritySearch(question); + List similarDocuments = store.similaritySearch(this.question); ``` Additional options can be passed into the `similaritySearch` method to define how many documents to retrieve and a threshold of the similarity search. @@ -282,7 +282,7 @@ A simple example is as follows: [source, java] ---- FilterExpressionBuilder b = new FilterExpressionBuilder(); -Expression expression = b.eq("country", "BG").build(); +Expression expression = this.b.eq("country", "BG").build(); ---- You can build up sophisticated expressions by using the following operators: diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc index aecdc6cc0..82677abef 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/apache-cassandra.adoc @@ -177,7 +177,7 @@ or programmatically using the expression DSL: [source,java] ---- Filter.Expression f = new FilterExpressionBuilder() - .and(f.in("country", "UK", "NL"), f.gte("year", 2020)).build(); + .and(f.in("country", "UK", "NL"), this.f.gte("year", 2020)).build(); vectorStore.similaritySearch( SearchRequest.query("The World").withTopK(TOP_K) diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc index df080dbb1..2e1cac5af 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc @@ -68,13 +68,13 @@ public class DemoApplication implements CommandLineRunner { public void run(String... args) throws Exception { Document document1 = new Document(UUID.randomUUID().toString(), "Sample content1", Map.of("key1", "value1")); Document document2 = new Document(UUID.randomUUID().toString(), "Sample content2", Map.of("key2", "value2")); - vectorStore.add(List.of(document1, document2)); - List results = vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); + this.vectorStore.add(List.of(document1, document2)); + List results = this.vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); log.info("Search results: {}", results); // Remove the documents from the vector store - vectorStore.delete(List.of(document1.getId(), document2.getId())); + this.vectorStore.delete(List.of(document1.getId(), document2.getId())); } @Bean @@ -133,15 +133,15 @@ metadata2.put("country", "NL"); metadata2.put("year", 2022); metadata2.put("city", "Amsterdam"); -Document document1 = new Document("1", "A document about the UK", metadata1); -Document document2 = new Document("2", "A document about the Netherlands", metadata2); +Document document1 = new Document("1", "A document about the UK", this.metadata1); +Document document2 = new Document("2", "A document about the Netherlands", this.metadata2); vectorStore.add(List.of(document1, document2)); FilterExpressionBuilder builder = new FilterExpressionBuilder(); List results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(10) - .withFilterExpression((builder.in("country", "UK", "NL")).build())); + .withFilterExpression((this.builder.in("country", "UK", "NL")).build())); ---- == Setting up Azure Cosmos DB Vector Store without Auto Configuration @@ -190,9 +190,9 @@ public class DemoApplication implements CommandLineRunner { public void run(String... args) throws Exception { Document document1 = new Document(UUID.randomUUID().toString(), "Sample content1", Map.of("key1", "value1")); Document document2 = new Document(UUID.randomUUID().toString(), "Sample content2", Map.of("key2", "value2")); - vectorStore.add(List.of(document1, document2)); + this.vectorStore.add(List.of(document1, document2)); - List results = vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); + List results = this.vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); log.info("Search results: {}", results); } @@ -216,7 +216,7 @@ public class DemoApplication implements CommandLineRunner { .gatewayMode() .buildAsyncClient(); - return new CosmosDBVectorStore(observationRegistry, null, cosmosClient, config, embeddingModel); + return new CosmosDBVectorStore(observationRegistry, null, cosmosClient, config, this.embeddingModel); } @Bean diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc index 5c477a2ae..0bd1e5558 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc @@ -101,7 +101,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- === Configuration properties diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc index 5f2b0cbf7..fab97b3f0 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc @@ -101,7 +101,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- [[elasticsearchvector-properties]] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/hana.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/hana.adoc index 7cebf79e2..9718debdb 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/hana.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/hana.adoc @@ -183,7 +183,7 @@ public class CricketWorldCupRepository implements HanaVectorRepository, List> splitter = new TokenTextSplitter(); List documents = splitter.apply(reader.get()); log.info("{} documents created from pdf file: {}", documents.size(), pdf.getFilename()); - hanaCloudVectorStore.accept(documents); + this.hanaCloudVectorStore.accept(documents); return ResponseEntity.ok().body(String.format("%d documents created from pdf file: %s", documents.size(), pdf.getFilename())); } @@ -304,7 +304,7 @@ public class CricketWorldCupHanaController { var userMessage = new UserMessage(message); Prompt prompt = new Prompt(List.of(similarDocsMessage, userMessage)); - String generation = chatModel.call(prompt).getResult().getOutput().getContent(); + String generation = this.chatModel.call(prompt).getResult().getOutput().getContent(); log.info("Generation: {}", generation); return Map.of("generation", generation); } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/milvus.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/milvus.adoc index 939a00aae..7675196e4 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/milvus.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/milvus.adoc @@ -85,7 +85,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- === Manual Configuration diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc index b125371a6..af68e343d 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc @@ -252,7 +252,7 @@ List results = vectorStore.similaritySearch( .withQuery("learn how to grow things") .withTopK(2) .withSimilarityThreshold(0.5) - .withFilterExpression(b.eq("author", "A").build()) + .withFilterExpression(this.b.eq("author", "A").build()) ); ---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/opensearch.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/opensearch.adoc index d7437d89c..84bf23db1 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/opensearch.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/opensearch.adoc @@ -109,7 +109,7 @@ List documents = List.of( vectorStore.add(List.of(document)); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- === Configuration properties diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/oracle.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/oracle.adoc index 70271ae6e..7817d1202 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/oracle.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/oracle.adoc @@ -91,7 +91,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- [[oracle-properties]] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pgvector.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pgvector.adoc index e5c58458e..2380f0f1c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pgvector.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pgvector.adoc @@ -125,7 +125,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- [[pgvector-properties]] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pinecone.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pinecone.adoc index 31c493ced..4ca3f2fa7 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pinecone.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/pinecone.adoc @@ -96,7 +96,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- === Configuration properties diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/qdrant.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/qdrant.adoc index 573cc1068..6df3e7ee1 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/qdrant.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/qdrant.adoc @@ -98,7 +98,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- [[qdrant-vectorstore-properties]] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc index 04d16a93e..cd6337ba7 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc @@ -96,7 +96,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- === Configuration properties diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/typesense.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/typesense.adoc index f9ac86419..6c655601c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/typesense.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/typesense.adoc @@ -89,7 +89,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- === Configuration properties diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/weaviate.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/weaviate.adoc index 0bfd2e661..5c1ce1be0 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/weaviate.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/weaviate.adoc @@ -102,7 +102,7 @@ List documents = List.of( vectorStore.add(documents); // Retrieve documents similar to a query -List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); ---- [[weaviate-vectorstore-properties]] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc index c6189f898..ea7f72f75 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc @@ -58,7 +58,7 @@ public class OldSimpleAiController { @GetMapping("/ai/simple") Map completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatClient.call(message)); + return Map.of("generation", this.chatClient.call(message)); } } ``` @@ -77,7 +77,7 @@ public class SimpleAiController { @GetMapping("/ai/simple") Map completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", chatModel.call(message)); + return Map.of("generation", this.chatModel.call(message)); } } ``` @@ -109,7 +109,7 @@ class OldSimpleAiController { Map completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { return Map.of( "generation", - chatClient.call(message) + this.chatClient.call(message) ); } } @@ -131,7 +131,7 @@ class SimpleAiController { Map completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { return Map.of( "generation", - chatClient.prompt().user(message).call().content() + this.chatClient.prompt().user(message).call().content() ); } } diff --git a/spring-ai-docs/src/main/javadoc/overview.html b/spring-ai-docs/src/main/javadoc/overview.html index 47ee30f0d..7fb094d7a 100644 --- a/spring-ai-docs/src/main/javadoc/overview.html +++ b/spring-ai-docs/src/main/javadoc/overview.html @@ -1,3 +1,19 @@ + +

    diff --git a/spring-ai-retry/pom.xml b/spring-ai-retry/pom.xml index b5008cd65..848ac8983 100644 --- a/spring-ai-retry/pom.xml +++ b/spring-ai-retry/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java index fcd824693..44c405ca6 100644 --- a/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java +++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.retry; /** diff --git a/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java index bbc584335..53d99b1db 100644 --- a/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java +++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.retry; import java.io.IOException; @@ -40,21 +41,6 @@ import org.springframework.web.client.ResponseErrorHandler; */ public abstract class RetryUtils { - private static final Logger logger = LoggerFactory.getLogger(RetryUtils.class); - - public static final RetryTemplate DEFAULT_RETRY_TEMPLATE = RetryTemplate.builder() - .maxAttempts(10) - .retryOn(TransientAiException.class) - .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) - .withListener(new RetryListener() { - @Override - public void onError(RetryContext context, - RetryCallback callback, Throwable throwable) { - logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable); - }; - }) - .build(); - public static final ResponseErrorHandler DEFAULT_RESPONSE_ERROR_HANDLER = new ResponseErrorHandler() { @Override @@ -81,4 +67,20 @@ public abstract class RetryUtils { } }; + private static final Logger logger = LoggerFactory.getLogger(RetryUtils.class); + + public static final RetryTemplate DEFAULT_RETRY_TEMPLATE = RetryTemplate.builder() + .maxAttempts(10) + .retryOn(TransientAiException.class) + .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) + .withListener(new RetryListener() { + + @Override + public void onError(RetryContext context, + RetryCallback callback, Throwable throwable) { + logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable); + } + }) + .build(); + } diff --git a/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java index 94a710484..95b6e37f6 100644 --- a/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java +++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.retry; /** diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 2a738a8d4..200acce9e 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -1,4 +1,20 @@ + + diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java index ff311b7ca..a49203d29 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.anthropic; import java.util.List; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; @@ -39,8 +42,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import io.micrometer.observation.ObservationRegistry; - /** * @author Christian Tzolov * @author Thomas Vitale diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicChatProperties.java index b83ba4540..2a7778689 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.anthropic; import org.springframework.ai.anthropic.AnthropicChatModel; @@ -52,12 +53,12 @@ public class AnthropicChatProperties { return this.options; } - public void setEnabled(boolean enabled) { - this.enabled = enabled; - } - public boolean isEnabled() { return this.enabled; } + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicConnectionProperties.java index 3ad6e4ed9..53a74e351 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.anthropic; import org.springframework.ai.anthropic.api.AnthropicApi; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAudioTranscriptionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAudioTranscriptionProperties.java index c223713e8..b3b10416f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAudioTranscriptionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAudioTranscriptionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.openai; import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions; @@ -36,7 +37,7 @@ public class AzureOpenAiAudioTranscriptionProperties { private AzureOpenAiAudioTranscriptionOptions options = AzureOpenAiAudioTranscriptionOptions.builder().build(); public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { @@ -44,7 +45,7 @@ public class AzureOpenAiAudioTranscriptionProperties { } public AzureOpenAiAudioTranscriptionOptions getOptions() { - return options; + return this.options; } public void setOptions(AzureOpenAiAudioTranscriptionOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java index ee686d44c..c40b97461 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.openai; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.core.credential.KeyCredential; +import com.azure.core.credential.TokenCredential; +import com.azure.core.util.ClientOptions; +import com.azure.core.util.Header; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionModel; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel; @@ -39,15 +48,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import com.azure.ai.openai.OpenAIClientBuilder; -import com.azure.core.credential.AzureKeyCredential; -import com.azure.core.credential.KeyCredential; -import com.azure.core.credential.TokenCredential; -import com.azure.core.util.ClientOptions; -import com.azure.core.util.Header; - -import io.micrometer.observation.ObservationRegistry; - /** * @author Piotr Olaszewski * @author Soby Chacko diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java index 7ae5ebc8d..58521d28c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.openai; import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java index 16a128260..6ede4e8b3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -54,14 +54,14 @@ public class AzureOpenAiConnectionProperties { this.endpoint = endpoint; } - public void setApiKey(String apiKey) { - this.apiKey = apiKey; - } - public String getApiKey() { return this.apiKey; } + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + public String getOpenAiApiKey() { return this.openAiApiKey; } @@ -71,7 +71,7 @@ public class AzureOpenAiConnectionProperties { } public Map getCustomHeaders() { - return customHeaders; + return this.customHeaders; } public void setCustomHeaders(Map customHeaders) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java index d7eb357ba..eb88e4f6f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.openai; import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingOptions; @@ -39,7 +40,7 @@ public class AzureOpenAiEmbeddingProperties { private MetadataMode metadataMode = MetadataMode.EMBED; public AzureOpenAiEmbeddingOptions getOptions() { - return options; + return this.options; } public void setOptions(AzureOpenAiEmbeddingOptions options) { @@ -48,7 +49,7 @@ public class AzureOpenAiEmbeddingProperties { } public MetadataMode getMetadataMode() { - return metadataMode; + return this.metadataMode; } public void setMetadataMode(MetadataMode metadataMode) { @@ -57,7 +58,7 @@ public class AzureOpenAiEmbeddingProperties { } public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiImageOptionsProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiImageOptionsProperties.java index 26e1ae2c8..4aea1459f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiImageOptionsProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiImageOptionsProperties.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.autoconfigure.azure.openai; import org.springframework.ai.azure.openai.AzureOpenAiImageOptions; @@ -24,7 +40,7 @@ public class AzureOpenAiImageOptionsProperties { private AzureOpenAiImageOptions options = AzureOpenAiImageOptions.builder().build(); public AzureOpenAiImageOptions getOptions() { - return options; + return this.options; } public void setOptions(AzureOpenAiImageOptions options) { @@ -32,7 +48,7 @@ public class AzureOpenAiImageOptionsProperties { } public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java index 46ee80405..10db75782 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionProperties.java index da8d3f34f..6d5338366 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock; -import org.springframework.boot.context.properties.ConfigurationProperties; - import java.time.Duration; +import org.springframework.boot.context.properties.ConfigurationProperties; + /** * Configuration properties for Bedrock AWS connection. * @@ -51,7 +52,7 @@ public class BedrockAwsConnectionProperties { private Duration timeout = Duration.ofMinutes(5L); public String getRegion() { - return region; + return this.region; } public void setRegion(String awsRegion) { @@ -59,7 +60,7 @@ public class BedrockAwsConnectionProperties { } public String getAccessKey() { - return accessKey; + return this.accessKey; } public void setAccessKey(String accessKey) { @@ -67,7 +68,7 @@ public class BedrockAwsConnectionProperties { } public String getSecretKey() { - return secretKey; + return this.secretKey; } public void setSecretKey(String secretKey) { @@ -75,7 +76,7 @@ public class BedrockAwsConnectionProperties { } public Duration getTimeout() { - return timeout; + return this.timeout; } public void setTimeout(Duration timeout) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java index 5d3f2d5fb..77a891283 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.anthropic; import com.fasterxml.jackson.databind.ObjectMapper; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; + import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.anthropic.BedrockAnthropicChatModel; @@ -28,8 +32,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Anthropic Chat Client. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java index daee6365b..5f2519600 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.anthropic; import java.util.List; @@ -70,7 +71,7 @@ public class BedrockAnthropicChatProperties { } public AnthropicChatOptions getOptions() { - return options; + return this.options; } public void setOptions(AnthropicChatOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java index ac4f78835..f385c18ac 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.anthropic3; import com.fasterxml.jackson.databind.ObjectMapper; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; + import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatModel; @@ -28,8 +32,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Anthropic Chat Client. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatProperties.java index 96ddc3e06..ef981bba4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.anthropic3; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; import org.springframework.ai.bedrock.anthropic3.Anthropic3ChatOptions; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; @@ -70,7 +71,7 @@ public class BedrockAnthropic3ChatProperties { } public Anthropic3ChatOptions getOptions() { - return options; + return this.options; } public void setOptions(Anthropic3ChatOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java index 706474ca3..95e1a1896 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.cohere; import com.fasterxml.jackson.databind.ObjectMapper; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; + import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.cohere.BedrockCohereChatModel; @@ -28,8 +32,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Cohere Chat Client. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java index 0381d591b..723c91ef2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.cohere; import org.springframework.ai.bedrock.cohere.BedrockCohereChatOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java index 27b1edd44..86ba3f76b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.cohere; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java index 8e6327941..3841fb6f3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.cohere; import org.springframework.ai.bedrock.cohere.BedrockCohereEmbeddingOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java index ac46a66f6..b84fbb112 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -17,6 +17,9 @@ package org.springframework.ai.autoconfigure.bedrock.jurrasic2; import com.fasterxml.jackson.databind.ObjectMapper; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; + import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.jurassic2.BedrockAi21Jurassic2ChatModel; @@ -29,8 +32,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Jurassic2 Chat Client. @@ -68,4 +69,4 @@ public class BedrockAi21Jurassic2ChatAutoConfiguration { .build(); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatProperties.java index 183c050bc..6ed3f4618 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java index 59341b7f1..b97204ea3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.llama; import com.fasterxml.jackson.databind.ObjectMapper; -import org.springframework.ai.bedrock.llama.BedrockLlamaChatModel; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.providers.AwsRegionProvider; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.bedrock.llama.BedrockLlamaChatModel; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java index c58742ee4..979e4fae4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.llama; import org.springframework.ai.bedrock.llama.BedrockLlamaChatOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java index 639085cb7..c6b0e19ae 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.titan; import com.fasterxml.jackson.databind.ObjectMapper; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; + import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.titan.BedrockTitanChatModel; @@ -28,8 +32,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Titan Chat Client. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java index 4b6df741a..d54327a6d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.titan; import org.springframework.ai.bedrock.titan.BedrockTitanChatOptions; @@ -45,7 +46,7 @@ public class BedrockTitanChatProperties { private BedrockTitanChatOptions options = BedrockTitanChatOptions.builder().withTemperature(0.7).build(); public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { @@ -53,7 +54,7 @@ public class BedrockTitanChatProperties { } public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -61,7 +62,7 @@ public class BedrockTitanChatProperties { } public BedrockTitanChatOptions getOptions() { - return options; + return this.options; } public void setOptions(BedrockTitanChatOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java index 37b63e1ab..96c6cfa8c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.titan; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java index 5136d7575..b0c1e6048 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.titan; import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingModel.InputType; @@ -46,8 +47,12 @@ public class BedrockTitanEmbeddingProperties { */ private InputType inputType = InputType.IMAGE; + public static String getConfigPrefix() { + return CONFIG_PREFIX; + } + public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { @@ -55,23 +60,19 @@ public class BedrockTitanEmbeddingProperties { } public String getModel() { - return model; + return this.model; } public void setModel(String model) { this.model = model; } - public static String getConfigPrefix() { - return CONFIG_PREFIX; + public InputType getInputType() { + return this.inputType; } public void setInputType(InputType inputType) { this.inputType = inputType; } - public InputType getInputType() { - return inputType; - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfiguration.java index 9e5a9185f..b1d9f9aa9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,10 @@ package org.springframework.ai.autoconfigure.chat.client; +import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClientCustomizer; import org.springframework.ai.chat.client.observation.ChatClientInputContentObservationFilter; @@ -33,8 +35,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Scope; -import io.micrometer.observation.ObservationRegistry; - /** * {@link EnableAutoConfiguration Auto-configuration} for {@link ChatClient}. *

    diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderConfigurer.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderConfigurer.java index 02c55ee52..a59653855 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderConfigurer.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderProperties.java index 102c5b745..91065c189 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/client/ChatClientBuilderProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.client; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -58,7 +59,7 @@ public class ChatClientBuilderProperties { private boolean includeInput = false; public boolean isIncludeInput() { - return includeInput; + return this.includeInput; } public void setIncludeInput(boolean includeCompletion) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/CommonChatMemoryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/CommonChatMemoryProperties.java index a635c70f6..9cd3c9152 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/CommonChatMemoryProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/CommonChatMemoryProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.memory; /** @@ -24,7 +25,7 @@ public class CommonChatMemoryProperties { private boolean initializeSchema = true; public boolean isInitializeSchema() { - return initializeSchema; + return this.initializeSchema; } public void setInitializeSchema(boolean initializeSchema) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfiguration.java index b84011dfe..9e5cfac94 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.memory.cassandra; import com.datastax.oss.driver.api.core.CqlSession; + import org.springframework.ai.chat.memory.CassandraChatMemory; import org.springframework.ai.chat.memory.CassandraChatMemoryConfig; - import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryProperties.java index 91d2252bc..fc0f45e5c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.memory.cassandra; import java.time.Duration; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -62,7 +64,7 @@ public class CassandraChatMemoryProperties extends CommonChatMemoryProperties { } public String getAssistantColumn() { - return assistantColumn; + return this.assistantColumn; } public void setAssistantColumn(String assistantColumn) { @@ -70,7 +72,7 @@ public class CassandraChatMemoryProperties extends CommonChatMemoryProperties { } public String getUserColumn() { - return userColumn; + return this.userColumn; } public void setUserColumn(String userColumn) { @@ -79,7 +81,7 @@ public class CassandraChatMemoryProperties extends CommonChatMemoryProperties { @Nullable public Duration getTimeToLiveSeconds() { - return timeToLiveSeconds; + return this.timeToLiveSeconds; } public void setTimeToLiveSeconds(Duration timeToLiveSeconds) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfiguration.java index 278ebd35c..f439263fd 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.observation; +import java.util.List; + import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.tracing.Tracer; import io.micrometer.tracing.otel.bridge.OtelTracer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationContext; import org.springframework.ai.chat.client.observation.ChatClientObservationContext; import org.springframework.ai.chat.model.ChatModel; @@ -44,8 +48,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import java.util.List; - /** * Auto-configuration for Spring AI chat model observations. * @@ -60,6 +62,16 @@ public class ChatObservationAutoConfiguration { private static final Logger logger = LoggerFactory.getLogger(ChatObservationAutoConfiguration.class); + private static void logPromptContentWarning() { + logger.warn( + "You have enabled the inclusion of the prompt content in the observations, with the risk of exposing sensitive or private information. Please, be careful!"); + } + + private static void logCompletionWarning() { + logger.warn( + "You have enabled the inclusion of the completion content in the observations, with the risk of exposing sensitive or private information. Please, be careful!"); + } + @Bean @ConditionalOnMissingBean @ConditionalOnBean(MeterRegistry.class) @@ -141,14 +153,4 @@ public class ChatObservationAutoConfiguration { } - private static void logPromptContentWarning() { - logger.warn( - "You have enabled the inclusion of the prompt content in the observations, with the risk of exposing sensitive or private information. Please, be careful!"); - } - - private static void logCompletionWarning() { - logger.warn( - "You have enabled the inclusion of the completion content in the observations, with the risk of exposing sensitive or private information. Please, be careful!"); - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationProperties.java index 750f49112..cd353ac0b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.observation; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -44,7 +45,7 @@ public class ChatObservationProperties { private boolean includeErrorLogging = false; public boolean isIncludeCompletion() { - return includeCompletion; + return this.includeCompletion; } public void setIncludeCompletion(boolean includeCompletion) { @@ -52,7 +53,7 @@ public class ChatObservationProperties { } public boolean isIncludePrompt() { - return includePrompt; + return this.includePrompt; } public void setIncludePrompt(boolean includePrompt) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/package-info.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/package-info.java index 5d159e12a..1a9623e32 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/package-info.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfiguration.java index e2c8f62aa..ba1d61886 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.embedding.observation; import io.micrometer.core.instrument.MeterRegistry; + import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.observation.EmbeddingModelMeterObservationHandler; import org.springframework.beans.factory.ObjectProvider; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/package-info.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/package-info.java index 1d7239f59..dfba7c661 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/package-info.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/embedding/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfiguration.java index ef28fe2b3..eaf74d75f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.huggingface; import org.springframework.ai.huggingface.HuggingfaceChatModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatProperties.java index e64fff584..cf844436d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.huggingface; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -44,7 +45,7 @@ public class HuggingfaceChatProperties { private boolean enabled = true; public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -52,7 +53,7 @@ public class HuggingfaceChatProperties { } public String getUrl() { - return url; + return this.url; } public void setUrl(String url) { @@ -60,7 +61,7 @@ public class HuggingfaceChatProperties { } public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfiguration.java index e54333f1a..a89a42102 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.image.observation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.observation.ImageModelPromptContentObservationFilter; import org.springframework.boot.autoconfigure.AutoConfiguration; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationProperties.java index 5663e2854..3e454ee8d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.image.observation; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -34,7 +35,7 @@ public class ImageObservationProperties { private boolean includePrompt = false; public boolean isIncludePrompt() { - return includePrompt; + return this.includePrompt; } public void setIncludePrompt(boolean includePrompt) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/package-info.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/package-info.java index f95f9e63c..019b02af1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/package-info.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/image/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfiguration.java index e4c549140..6b7c8307c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; import java.util.List; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; @@ -40,8 +43,6 @@ import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import io.micrometer.observation.ObservationRegistry; - /** * @author Geng Rong */ diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxChatProperties.java index 5ca297949..df7b1f0fe 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; import org.springframework.ai.minimax.MiniMaxChatOptions; @@ -44,7 +45,7 @@ public class MiniMaxChatProperties extends MiniMaxParentProperties { .build(); public MiniMaxChatOptions getOptions() { - return options; + return this.options; } public void setOptions(MiniMaxChatOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxConnectionProperties.java index 1019e8499..59d5ff0ca 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; import org.springframework.boot.context.properties.ConfigurationProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxEmbeddingProperties.java index bfdb49174..cb4023331 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; import org.springframework.ai.document.MetadataMode; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxParentProperties.java index 1f8f9f6b7..34b98c0c1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/minimax/MiniMaxParentProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; /** @@ -25,7 +26,7 @@ class MiniMaxParentProperties { private String baseUrl; public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -33,7 +34,7 @@ class MiniMaxParentProperties { } public String getBaseUrl() { - return baseUrl; + return this.baseUrl; } public void setBaseUrl(String baseUrl) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java index be572d819..95241a569 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.mistralai; import java.util.List; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; @@ -42,8 +45,6 @@ import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import io.micrometer.observation.ObservationRegistry; - /** * @author Ricken Bazolo * @author Christian Tzolov diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiChatProperties.java index 9e46cc830..af39bc374 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.mistralai; import org.springframework.ai.mistralai.MistralAiChatOptions; @@ -39,10 +40,6 @@ public class MistralAiChatProperties extends MistralAiParentProperties { private static final Boolean IS_ENABLED = false; - public MistralAiChatProperties() { - super.setBaseUrl(MistralAiCommonProperties.DEFAULT_BASE_URL); - } - /** * Enable OpenAI chat model. */ @@ -56,6 +53,10 @@ public class MistralAiChatProperties extends MistralAiParentProperties { .withTopP(DEFAULT_TOP_P) .build(); + public MistralAiChatProperties() { + super.setBaseUrl(MistralAiCommonProperties.DEFAULT_BASE_URL); + } + public MistralAiChatOptions getOptions() { return this.options; } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiCommonProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiCommonProperties.java index 0bc7132be..54023eb63 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiCommonProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiCommonProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.mistralai; import org.springframework.boot.context.properties.ConfigurationProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiEmbeddingProperties.java index 450ac479a..633f36097 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.mistralai; import org.springframework.ai.document.MetadataMode; @@ -34,13 +35,13 @@ public class MistralAiEmbeddingProperties extends MistralAiParentProperties { public static final String DEFAULT_ENCODING_FORMAT = "float"; + public MetadataMode metadataMode = MetadataMode.EMBED; + /** * Enable MistralAI embedding model. */ private boolean enabled = true; - public MetadataMode metadataMode = MetadataMode.EMBED; - @NestedConfigurationProperty private MistralAiEmbeddingOptions options = MistralAiEmbeddingOptions.builder() .withModel(DEFAULT_EMBEDDING_MODEL) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiParentProperties.java index 31c632af8..f2d398a58 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiParentProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.mistralai; /** diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfiguration.java index 3bd4223f4..cc3877fdd 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot; import java.util.List; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.model.function.FunctionCallback; @@ -38,8 +41,6 @@ import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import io.micrometer.observation.ObservationRegistry; - /** * @author Geng Rong */ diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotChatProperties.java index 91918ff00..299648fe1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot; import org.springframework.ai.moonshot.MoonshotChatOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotCommonProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotCommonProperties.java index 07525a311..953f41379 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotCommonProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotCommonProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot; import org.springframework.boot.context.properties.ConfigurationProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotParentProperties.java index 54f3f5f3b..ed27cfba4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/moonshot/MoonshotParentProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot; /** diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIConnectionProperties.java index 6de993705..d64c88f9a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.oci.genai; import java.nio.file.Paths; @@ -26,26 +27,9 @@ import org.springframework.util.StringUtils; @ConfigurationProperties(OCIConnectionProperties.CONFIG_PREFIX) public class OCIConnectionProperties { - private static final String DEFAULT_PROFILE = "DEFAULT"; - public static final String CONFIG_PREFIX = "spring.ai.oci.genai"; - public enum AuthenticationType { - - FILE("file"), INSTANCE_PRINCIPAL("instance-principal"), WORKLOAD_IDENTITY("workload-identity"), - SIMPLE("simple"); - - private final String authType; - - AuthenticationType(String authType) { - this.authType = authType; - } - - public String getAuthType() { - return this.authType; - } - - } + private static final String DEFAULT_PROFILE = "DEFAULT"; private AuthenticationType authenticationType = AuthenticationType.FILE; @@ -68,7 +52,7 @@ public class OCIConnectionProperties { private String endpoint; public String getRegion() { - return region; + return this.region; } public void setRegion(String region) { @@ -76,7 +60,7 @@ public class OCIConnectionProperties { } public String getPassPhrase() { - return passPhrase; + return this.passPhrase; } public void setPassPhrase(String passPhrase) { @@ -84,7 +68,7 @@ public class OCIConnectionProperties { } public String getPrivateKey() { - return privateKey; + return this.privateKey; } public void setPrivateKey(String privateKey) { @@ -92,7 +76,7 @@ public class OCIConnectionProperties { } public String getFingerprint() { - return fingerprint; + return this.fingerprint; } public void setFingerprint(String fingerprint) { @@ -100,7 +84,7 @@ public class OCIConnectionProperties { } public String getUserId() { - return userId; + return this.userId; } public void setUserId(String userId) { @@ -108,7 +92,7 @@ public class OCIConnectionProperties { } public String getTenantId() { - return tenantId; + return this.tenantId; } public void setTenantId(String tenantId) { @@ -116,7 +100,7 @@ public class OCIConnectionProperties { } public String getFile() { - return file; + return this.file; } public void setFile(String file) { @@ -124,7 +108,7 @@ public class OCIConnectionProperties { } public String getProfile() { - return StringUtils.hasText(profile) ? profile : DEFAULT_PROFILE; + return StringUtils.hasText(this.profile) ? this.profile : DEFAULT_PROFILE; } public void setProfile(String profile) { @@ -132,7 +116,7 @@ public class OCIConnectionProperties { } public AuthenticationType getAuthenticationType() { - return authenticationType; + return this.authenticationType; } public void setAuthenticationType(AuthenticationType authenticationType) { @@ -140,11 +124,28 @@ public class OCIConnectionProperties { } public String getEndpoint() { - return endpoint; + return this.endpoint; } public void setEndpoint(String endpoint) { this.endpoint = endpoint; } + public enum AuthenticationType { + + FILE("file"), INSTANCE_PRINCIPAL("instance-principal"), WORKLOAD_IDENTITY("workload-identity"), + SIMPLE("simple"); + + private final String authType; + + AuthenticationType(String authType) { + this.authType = authType; + } + + public String getAuthType() { + return this.authType; + } + + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIEmbeddingModelProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIEmbeddingModelProperties.java index cc2ef2b5f..034836a25 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIEmbeddingModelProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIEmbeddingModelProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.oci.genai; import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails; + import org.springframework.ai.oci.OCIEmbeddingOptions; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -39,15 +41,15 @@ public class OCIEmbeddingModelProperties { public OCIEmbeddingOptions getEmbeddingOptions() { return OCIEmbeddingOptions.builder() - .withCompartment(compartment) - .withModel(model) - .withServingMode(servingMode.getMode()) - .withTruncate(truncate) + .withCompartment(this.compartment) + .withModel(this.model) + .withServingMode(this.servingMode.getMode()) + .withTruncate(this.truncate) .build(); } public ServingMode getServingMode() { - return servingMode; + return this.servingMode; } public void setServingMode(ServingMode servingMode) { @@ -55,7 +57,7 @@ public class OCIEmbeddingModelProperties { } public String getCompartment() { - return compartment; + return this.compartment; } public void setCompartment(String compartment) { @@ -63,7 +65,7 @@ public class OCIEmbeddingModelProperties { } public String getModel() { - return model; + return this.model; } public void setModel(String model) { @@ -71,7 +73,7 @@ public class OCIEmbeddingModelProperties { } public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { @@ -79,7 +81,7 @@ public class OCIEmbeddingModelProperties { } public EmbedTextDetails.Truncate getTruncate() { - return truncate; + return this.truncate; } public void setTruncate(EmbedTextDetails.Truncate truncate) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java index 9e8a64059..681ee71f3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.oci.genai; import java.io.IOException; @@ -27,6 +28,7 @@ import com.oracle.bmc.auth.SimplePrivateKeySupplier; import com.oracle.bmc.auth.okeworkloadidentity.OkeWorkloadIdentityAuthenticationDetailsProvider; import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient; import com.oracle.bmc.retrier.RetryConfiguration; + import org.springframework.ai.oci.OCIEmbeddingModel; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; @@ -44,6 +46,23 @@ import org.springframework.util.StringUtils; @EnableConfigurationProperties({ OCIConnectionProperties.class, OCIEmbeddingModelProperties.class }) public class OCIGenAiAutoConfiguration { + private static BasicAuthenticationDetailsProvider authenticationProvider(OCIConnectionProperties properties) + throws IOException { + return switch (properties.getAuthenticationType()) { + case FILE -> new ConfigFileAuthenticationDetailsProvider(properties.getFile(), properties.getProfile()); + case INSTANCE_PRINCIPAL -> InstancePrincipalsAuthenticationDetailsProvider.builder().build(); + case WORKLOAD_IDENTITY -> OkeWorkloadIdentityAuthenticationDetailsProvider.builder().build(); + case SIMPLE -> SimpleAuthenticationDetailsProvider.builder() + .userId(properties.getUserId()) + .tenantId(properties.getTenantId()) + .fingerprint(properties.getFingerprint()) + .privateKeySupplier(new SimplePrivateKeySupplier(properties.getPrivateKey())) + .passPhrase(properties.getPassPhrase()) + .region(Region.valueOf(properties.getRegion())) + .build(); + }; + } + @ConditionalOnMissingBean @Bean public GenerativeAiInferenceClient generativeAiInferenceClient(OCIConnectionProperties properties) @@ -70,21 +89,4 @@ public class OCIGenAiAutoConfiguration { return new OCIEmbeddingModel(generativeAiClient, properties.getEmbeddingOptions()); } - private static BasicAuthenticationDetailsProvider authenticationProvider(OCIConnectionProperties properties) - throws IOException { - return switch (properties.getAuthenticationType()) { - case FILE -> new ConfigFileAuthenticationDetailsProvider(properties.getFile(), properties.getProfile()); - case INSTANCE_PRINCIPAL -> InstancePrincipalsAuthenticationDetailsProvider.builder().build(); - case WORKLOAD_IDENTITY -> OkeWorkloadIdentityAuthenticationDetailsProvider.builder().build(); - case SIMPLE -> SimpleAuthenticationDetailsProvider.builder() - .userId(properties.getUserId()) - .tenantId(properties.getTenantId()) - .fingerprint(properties.getFingerprint()) - .privateKeySupplier(new SimplePrivateKeySupplier(properties.getPrivateKey())) - .passPhrase(properties.getPassPhrase()) - .region(Region.valueOf(properties.getRegion())) - .build(); - }; - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/ServingMode.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/ServingMode.java index 7cb2299a2..291a056b4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/ServingMode.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/ServingMode.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.oci.genai; /** @@ -29,7 +30,7 @@ public enum ServingMode { } public String getMode() { - return mode; + return this.mode; } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java index 453c23726..9b86be150 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import java.util.List; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.model.function.FunctionCallback; @@ -40,8 +43,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import io.micrometer.observation.ObservationRegistry; - /** * {@link AutoConfiguration Auto-configuration} for Ollama Chat Client. * @@ -124,6 +125,14 @@ public class OllamaAutoConfiguration { return embeddingModel; } + @Bean + @ConditionalOnMissingBean + public FunctionCallbackContext springAiFunctionManager(ApplicationContext context) { + FunctionCallbackContext manager = new FunctionCallbackContext(); + manager.setApplicationContext(context); + return manager; + } + static class PropertiesOllamaConnectionDetails implements OllamaConnectionDetails { private final OllamaConnectionProperties properties; @@ -139,12 +148,4 @@ public class OllamaAutoConfiguration { } - @Bean - @ConditionalOnMissingBean - public FunctionCallbackContext springAiFunctionManager(ApplicationContext context) { - FunctionCallbackContext manager = new FunctionCallbackContext(); - manager.setApplicationContext(context); - return manager; - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java index c106c8c35..ef60a94dc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import org.springframework.ai.ollama.api.OllamaModel; @@ -56,12 +57,12 @@ public class OllamaChatProperties { return this.options; } - public void setEnabled(boolean enabled) { - this.enabled = enabled; - } - public boolean isEnabled() { return this.enabled; } + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionDetails.java index 6981097c3..9e392486f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionProperties.java index 160849ed8..46f127e13 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -34,7 +35,7 @@ public class OllamaConnectionProperties { private String baseUrl = "http://localhost:11434"; public String getBaseUrl() { - return baseUrl; + return this.baseUrl; } public void setBaseUrl(String baseUrl) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java index 9b21a92d5..b159fe7be 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import org.springframework.ai.ollama.api.OllamaModel; @@ -56,12 +57,12 @@ public class OllamaEmbeddingProperties { return this.options; } - public void setEnabled(boolean enabled) { - this.enabled = enabled; - } - public boolean isEnabled() { return this.enabled; } + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaInitializationProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaInitializationProperties.java index b884404be..54c764f27 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaInitializationProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaInitializationProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.ollama; -import org.springframework.ai.ollama.management.PullModelStrategy; -import org.springframework.boot.context.properties.ConfigurationProperties; +package org.springframework.ai.autoconfigure.ollama; import java.time.Duration; import java.util.List; +import org.springframework.ai.ollama.management.PullModelStrategy; +import org.springframework.boot.context.properties.ConfigurationProperties; + /** * Ollama initialization configuration properties. * @@ -32,11 +33,6 @@ public class OllamaInitializationProperties { public static final String CONFIG_PREFIX = "spring.ai.ollama.init"; - /** - * Whether to pull models at startup-time and how. - */ - private PullModelStrategy pullModelStrategy = PullModelStrategy.NEVER; - /** * Chat models initialization settings. */ @@ -47,6 +43,11 @@ public class OllamaInitializationProperties { */ private final ModelTypeInit embedding = new ModelTypeInit(); + /** + * Whether to pull models at startup-time and how. + */ + private PullModelStrategy pullModelStrategy = PullModelStrategy.NEVER; + /** * How long to wait for a model to be pulled. */ @@ -58,7 +59,7 @@ public class OllamaInitializationProperties { private int maxRetries = 0; public PullModelStrategy getPullModelStrategy() { - return pullModelStrategy; + return this.pullModelStrategy; } public void setPullModelStrategy(PullModelStrategy pullModelStrategy) { @@ -66,15 +67,15 @@ public class OllamaInitializationProperties { } public ModelTypeInit getChat() { - return chat; + return this.chat; } public ModelTypeInit getEmbedding() { - return embedding; + return this.embedding; } public Duration getTimeout() { - return timeout; + return this.timeout; } public void setTimeout(Duration timeout) { @@ -82,7 +83,7 @@ public class OllamaInitializationProperties { } public int getMaxRetries() { - return maxRetries; + return this.maxRetries; } public void setMaxRetries(int maxRetries) { @@ -103,7 +104,7 @@ public class OllamaInitializationProperties { private List additionalModels = List.of(); public boolean isInclude() { - return include; + return this.include; } public void setInclude(boolean include) { @@ -111,7 +112,7 @@ public class OllamaInitializationProperties { } public List getAdditionalModels() { - return additionalModels; + return this.additionalModels; } public void setAdditionalModels(List additionalModels) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioSpeechProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioSpeechProperties.java index 8d583c205..f7038e679 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioSpeechProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioSpeechProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -57,7 +57,7 @@ public class OpenAiAudioSpeechProperties extends OpenAiParentProperties { .build(); public OpenAiAudioSpeechOptions getOptions() { - return options; + return this.options; } public void setOptions(OpenAiAudioSpeechOptions options) { @@ -65,7 +65,7 @@ public class OpenAiAudioSpeechProperties extends OpenAiParentProperties { } public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioTranscriptionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioTranscriptionProperties.java index e8546e089..277fed648 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioTranscriptionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioTranscriptionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai; import org.springframework.ai.openai.OpenAiAudioTranscriptionOptions; @@ -44,7 +45,7 @@ public class OpenAiAudioTranscriptionProperties extends OpenAiParentProperties { .build(); public OpenAiAudioTranscriptionOptions getOptions() { - return options; + return this.options; } public void setOptions(OpenAiAudioTranscriptionOptions options) { @@ -52,7 +53,7 @@ public class OpenAiAudioTranscriptionProperties extends OpenAiParentProperties { } public boolean isEnabled() { - return enabled; + return this.enabled; } public void setEnabled(boolean enabled) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index d594460c5..3436ca326 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai; import java.util.HashMap; import java.util.List; import java.util.Map; +import io.micrometer.observation.ObservationRegistry; import org.jetbrains.annotations.NotNull; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; @@ -56,8 +59,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import io.micrometer.observation.ObservationRegistry; - /** * @author Christian Tzolov * @author Stefan Vassilev @@ -73,6 +74,36 @@ import io.micrometer.observation.ObservationRegistry; WebClientAutoConfiguration.class }) public class OpenAiAutoConfiguration { + private static @NotNull ResolvedConnectionProperties resolveConnectionProperties( + OpenAiParentProperties commonProperties, OpenAiParentProperties modelProperties, String modelType) { + + String baseUrl = StringUtils.hasText(modelProperties.getBaseUrl()) ? modelProperties.getBaseUrl() + : commonProperties.getBaseUrl(); + String apiKey = StringUtils.hasText(modelProperties.getApiKey()) ? modelProperties.getApiKey() + : commonProperties.getApiKey(); + String projectId = StringUtils.hasText(modelProperties.getProjectId()) ? modelProperties.getProjectId() + : commonProperties.getProjectId(); + String organizationId = StringUtils.hasText(modelProperties.getOrganizationId()) + ? modelProperties.getOrganizationId() : commonProperties.getOrganizationId(); + + Map> connectionHeaders = new HashMap<>(); + if (StringUtils.hasText(projectId)) { + connectionHeaders.put("OpenAI-Project", List.of(projectId)); + } + if (StringUtils.hasText(organizationId)) { + connectionHeaders.put("OpenAI-Organization", List.of(organizationId)); + } + + Assert.hasText(baseUrl, + "OpenAI base URL must be set. Use the connection property: spring.ai.openai.base-url or spring.ai.openai." + + modelType + ".base-url property."); + Assert.hasText(apiKey, + "OpenAI API key must be set. Use the connection property: spring.ai.openai.api-key or spring.ai.openai." + + modelType + ".api-key property."); + + return new ResolvedConnectionProperties(baseUrl, apiKey, CollectionUtils.toMultiValueMap(connectionHeaders)); + } + @Bean @ConditionalOnMissingBean @ConditionalOnProperty(prefix = OpenAiChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", @@ -229,37 +260,8 @@ public class OpenAiAutoConfiguration { return manager; } - private static @NotNull ResolvedConnectionProperties resolveConnectionProperties( - OpenAiParentProperties commonProperties, OpenAiParentProperties modelProperties, String modelType) { - - String baseUrl = StringUtils.hasText(modelProperties.getBaseUrl()) ? modelProperties.getBaseUrl() - : commonProperties.getBaseUrl(); - String apiKey = StringUtils.hasText(modelProperties.getApiKey()) ? modelProperties.getApiKey() - : commonProperties.getApiKey(); - String projectId = StringUtils.hasText(modelProperties.getProjectId()) ? modelProperties.getProjectId() - : commonProperties.getProjectId(); - String organizationId = StringUtils.hasText(modelProperties.getOrganizationId()) - ? modelProperties.getOrganizationId() : commonProperties.getOrganizationId(); - - Map> connectionHeaders = new HashMap<>(); - if (StringUtils.hasText(projectId)) { - connectionHeaders.put("OpenAI-Project", List.of(projectId)); - } - if (StringUtils.hasText(organizationId)) { - connectionHeaders.put("OpenAI-Organization", List.of(organizationId)); - } - - Assert.hasText(baseUrl, - "OpenAI base URL must be set. Use the connection property: spring.ai.openai.base-url or spring.ai.openai." - + modelType + ".base-url property."); - Assert.hasText(apiKey, - "OpenAI API key must be set. Use the connection property: spring.ai.openai.api-key or spring.ai.openai." - + modelType + ".api-key property."); - - return new ResolvedConnectionProperties(baseUrl, apiKey, CollectionUtils.toMultiValueMap(connectionHeaders)); - } - private record ResolvedConnectionProperties(String baseUrl, String apiKey, MultiValueMap headers) { + } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java index e2014de6b..007542e53 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai; import org.springframework.ai.openai.OpenAiChatOptions; @@ -26,10 +27,10 @@ public class OpenAiChatProperties extends OpenAiParentProperties { public static final String DEFAULT_CHAT_MODEL = "gpt-4o"; - private static final Double DEFAULT_TEMPERATURE = 0.7; - public static final String DEFAULT_COMPLETIONS_PATH = "/v1/chat/completions"; + private static final Double DEFAULT_TEMPERATURE = 0.7; + /** * Enable OpenAI chat model. */ @@ -44,7 +45,7 @@ public class OpenAiChatProperties extends OpenAiParentProperties { .build(); public OpenAiChatOptions getOptions() { - return options; + return this.options; } public void setOptions(OpenAiChatOptions options) { @@ -60,7 +61,7 @@ public class OpenAiChatProperties extends OpenAiParentProperties { } public String getCompletionsPath() { - return completionsPath; + return this.completionsPath; } public void setCompletionsPath(String completionsPath) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java index b065deb53..e6c6f582d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai; import org.springframework.boot.context.properties.ConfigurationProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java index 008a3c18d..7a0e5286f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai; import org.springframework.ai.document.MetadataMode; @@ -68,7 +69,7 @@ public class OpenAiEmbeddingProperties extends OpenAiParentProperties { } public String getEmbeddingsPath() { - return embeddingsPath; + return this.embeddingsPath; } public void setEmbeddingsPath(String embeddingsPath) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiImageProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiImageProperties.java index 06fb24bf6..7e14567ba 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiImageProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiImageProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai; import org.springframework.ai.openai.OpenAiImageOptions; @@ -45,7 +46,7 @@ public class OpenAiImageProperties extends OpenAiParentProperties { private OpenAiImageOptions options = OpenAiImageOptions.builder().withModel(DEFAULT_IMAGE_MODEL).build(); public OpenAiImageOptions getOptions() { - return options; + return this.options; } public void setOptions(OpenAiImageOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiModerationProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiModerationProperties.java index d468f591c..d9e709862 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiModerationProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiModerationProperties.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ public class OpenAiModerationProperties extends OpenAiParentProperties { private OpenAiModerationOptions options = OpenAiModerationOptions.builder().build(); public OpenAiModerationOptions getOptions() { - return options; + return this.options; } public void setOptions(OpenAiModerationOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java index 79aa3d833..7516ba844 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai; /** @@ -32,7 +33,7 @@ class OpenAiParentProperties { private String organizationId; public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -40,7 +41,7 @@ class OpenAiParentProperties { } public String getBaseUrl() { - return baseUrl; + return this.baseUrl; } public void setBaseUrl(String baseUrl) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java index ca30501b5..b2dc3da2f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.postgresml; import org.springframework.ai.postgresml.PostgresMlEmbeddingModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingProperties.java index 53dba7f93..d67f94405 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.postgresml; import java.util.Map; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfiguration.java index 5a2efccae..7c7fe359b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; @@ -40,8 +43,6 @@ import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import io.micrometer.observation.ObservationRegistry; - /** * @author Geng Rong */ diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanChatProperties.java index cd0edcd3d..9208cd53c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; import org.springframework.ai.qianfan.QianFanChatOptions; @@ -44,7 +45,7 @@ public class QianFanChatProperties extends QianFanParentProperties { .build(); public QianFanChatOptions getOptions() { - return options; + return this.options; } public void setOptions(QianFanChatOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanConnectionProperties.java index ff07a6291..90cb8c7a2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; import org.springframework.ai.qianfan.api.QianFanConstants; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanEmbeddingProperties.java index 2a2351100..97091f0b6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; import org.springframework.ai.document.MetadataMode; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanImageProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanImageProperties.java index 3d81043ca..5946747c8 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanImageProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanImageProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; import org.springframework.ai.qianfan.QianFanImageOptions; @@ -44,7 +45,7 @@ public class QianFanImageProperties extends QianFanParentProperties { private QianFanImageOptions options = QianFanImageOptions.builder().withModel(DEFAULT_IMAGE_MODEL).build(); public QianFanImageOptions getOptions() { - return options; + return this.options; } public void setOptions(QianFanImageOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanParentProperties.java index 109cc279b..543bec0c6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanParentProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; /** @@ -27,7 +28,7 @@ class QianFanParentProperties { private String baseUrl; public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -35,7 +36,7 @@ class QianFanParentProperties { } public String getSecretKey() { - return secretKey; + return this.secretKey; } public void setSecretKey(String secretKey) { @@ -43,7 +44,7 @@ class QianFanParentProperties { } public String getBaseUrl() { - return baseUrl; + return this.baseUrl; } public void setBaseUrl(String baseUrl) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java index 295ce7ca4..094120036 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.retry; import java.io.IOException; @@ -57,11 +58,14 @@ public class SpringAiRetryAutoConfiguration { .exponentialBackoff(properties.getBackoff().getInitialInterval(), properties.getBackoff().getMultiplier(), properties.getBackoff().getMaxInterval()) .withListener(new RetryListener() { + @Override public void onError(RetryContext context, RetryCallback callback, Throwable throwable) { logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable); - }; + } + + ; }) .build(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryProperties.java index 8b04b81e2..69f651794 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.retry; import java.time.Duration; @@ -59,52 +60,6 @@ public class SpringAiRetryProperties { */ private List onHttpCodes = new ArrayList<>(); - /** - * Exponential Backoff properties. - */ - public static class Backoff { - - /** - * Initial sleep duration. - */ - private Duration initialInterval = Duration.ofMillis(2000); - - /** - * Backoff interval multiplier. - */ - private int multiplier = 5; - - /** - * Maximum backoff duration. - */ - private Duration maxInterval = Duration.ofMillis(3 * 60000); - - public Duration getInitialInterval() { - return initialInterval; - } - - public void setInitialInterval(Duration initialInterval) { - this.initialInterval = initialInterval; - } - - public int getMultiplier() { - return multiplier; - } - - public void setMultiplier(int multiplier) { - this.multiplier = multiplier; - } - - public Duration getMaxInterval() { - return maxInterval; - } - - public void setMaxInterval(Duration maxInterval) { - this.maxInterval = maxInterval; - } - - } - public int getMaxAttempts() { return this.maxAttempts; } @@ -141,4 +96,50 @@ public class SpringAiRetryProperties { this.onHttpCodes = onHttpCodes; } + /** + * Exponential Backoff properties. + */ + public static class Backoff { + + /** + * Initial sleep duration. + */ + private Duration initialInterval = Duration.ofMillis(2000); + + /** + * Backoff interval multiplier. + */ + private int multiplier = 5; + + /** + * Maximum backoff duration. + */ + private Duration maxInterval = Duration.ofMillis(3 * 60000); + + public Duration getInitialInterval() { + return this.initialInterval; + } + + public void setInitialInterval(Duration initialInterval) { + this.initialInterval = initialInterval; + } + + public int getMultiplier() { + return this.multiplier; + } + + public void setMultiplier(int multiplier) { + this.multiplier = multiplier; + } + + public Duration getMaxInterval() { + return this.maxInterval; + } + + public void setMaxInterval(Duration maxInterval) { + this.maxInterval = maxInterval; + } + + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiConnectionProperties.java index 1cf0d5571..e39d36e7f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.stabilityai; import org.springframework.ai.stabilityai.api.StabilityAiApi; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageAutoConfiguration.java index 0a594ff0b..cf5f66cf6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.stabilityai; import org.springframework.ai.stabilityai.StabilityAiImageModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageProperties.java index d307750df..9af35a1f2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.stabilityai; import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; @@ -36,9 +37,10 @@ public class StabilityAiImageProperties extends StabilityAiParentProperties { @NestedConfigurationProperty private StabilityAiImageOptions options = StabilityAiImageOptions.builder().build(); // stable-diffusion-v1-6 - // is - // default - // model + + // is + // default + // model public StabilityAiImageOptions getOptions() { return this.options; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiParentProperties.java index f8b62cd8e..b62d9e5e3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiParentProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.stabilityai; /** @@ -28,7 +29,7 @@ class StabilityAiParentProperties { private String baseUrl; public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -36,7 +37,7 @@ class StabilityAiParentProperties { } public String getBaseUrl() { - return baseUrl; + return this.baseUrl; } public void setBaseUrl(String baseUrl) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java index 583d56863..482e733f9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.transformers; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.onnxruntime.OrtSession; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.transformers.TransformersEmbeddingModel; import org.springframework.beans.factory.ObjectProvider; @@ -25,10 +30,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; -import ai.onnxruntime.OrtSession; -import io.micrometer.observation.ObservationRegistry; - /** * @author Christian Tzolov */ diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelProperties.java index 2ffc590be..67e0922b6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.transformers; import java.io.File; @@ -42,11 +43,33 @@ public class TransformersEmbeddingModelProperties { "spring-ai-onnx-generative") .getAbsolutePath(); + @NestedConfigurationProperty + private final Tokenizer tokenizer = new Tokenizer(); + + /** + * Controls caching of remote, large resources to local file system. + */ + @NestedConfigurationProperty + private final Cache cache = new Cache(); + + @NestedConfigurationProperty + private final Onnx onnx = new Onnx(); + /** * Enable the Transformer Embedding model. */ private boolean enabled = true; + /** + * Specifies what parts of the {@link Document}'s content and metadata will be used + * for computing the embeddings. Applicable for the + * {@link TransformersEmbeddingModel#embed(Document)} method only. Has no effect on + * the {@link TransformersEmbeddingModel#embed(String)} or + * {@link TransformersEmbeddingModel#embed(List)}. Defaults to + * {@link MetadataMode#NONE}. + */ + private MetadataMode metadataMode = MetadataMode.NONE; + public boolean isEnabled() { return this.enabled; } @@ -55,6 +78,26 @@ public class TransformersEmbeddingModelProperties { this.enabled = enabled; } + public Cache getCache() { + return this.cache; + } + + public Onnx getOnnx() { + return this.onnx; + } + + public Tokenizer getTokenizer() { + return this.tokenizer; + } + + public MetadataMode getMetadataMode() { + return this.metadataMode; + } + + public void setMetadataMode(MetadataMode metadataMode) { + this.metadataMode = metadataMode; + } + /** * Configurations for the {@link HuggingFaceTokenizer} used to convert sentences into * tokens. @@ -93,9 +136,6 @@ public class TransformersEmbeddingModelProperties { } - @NestedConfigurationProperty - private final Tokenizer tokenizer = new Tokenizer(); - public static class Cache { /** @@ -128,16 +168,6 @@ public class TransformersEmbeddingModelProperties { } - /** - * Controls caching of remote, large resources to local file system. - */ - @NestedConfigurationProperty - private final Cache cache = new Cache(); - - public Cache getCache() { - return this.cache; - } - public static class Onnx { /** @@ -186,33 +216,4 @@ public class TransformersEmbeddingModelProperties { } - @NestedConfigurationProperty - private final Onnx onnx = new Onnx(); - - public Onnx getOnnx() { - return this.onnx; - } - - /** - * Specifies what parts of the {@link Document}'s content and metadata will be used - * for computing the embeddings. Applicable for the - * {@link TransformersEmbeddingModel#embed(Document)} method only. Has no effect on - * the {@link TransformersEmbeddingModel#embed(String)} or - * {@link TransformersEmbeddingModel#embed(List)}. Defaults to - * {@link MetadataMode#NONE}. - */ - private MetadataMode metadataMode = MetadataMode.NONE; - - public Tokenizer getTokenizer() { - return this.tokenizer; - } - - public MetadataMode getMetadataMode() { - return this.metadataMode; - } - - public void setMetadataMode(MetadataMode metadataMode) { - this.metadataMode = metadataMode; - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/CommonVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/CommonVectorStoreProperties.java index 5fd20bb55..db3d6e5b9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/CommonVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/CommonVectorStoreProperties.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore; /** @@ -30,7 +31,7 @@ public class CommonVectorStoreProperties { private boolean initializeSchema = false; public boolean isInitializeSchema() { - return initializeSchema; + return this.initializeSchema; } public void setInitializeSchema(boolean initializeSchema) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java index f38bda9dd..1e0d6bfeb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,14 @@ package org.springframework.ai.autoconfigure.vectorstore.azure; +import java.util.List; + import com.azure.core.credential.AzureKeyCredential; import com.azure.core.util.ClientOptions; import com.azure.search.documents.indexes.SearchIndexClient; import com.azure.search.documents.indexes.SearchIndexClientBuilder; - import io.micrometer.observation.ObservationRegistry; -import java.util.List; - import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreProperties.java index 6de3a5730..661807c7b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.azure; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; @@ -38,7 +39,7 @@ public class AzureVectorStoreProperties extends CommonVectorStoreProperties { private double defaultSimilarityThreshold = -1; public String getUrl() { - return url; + return this.url; } public void setUrl(String endpointUrl) { @@ -46,7 +47,7 @@ public class AzureVectorStoreProperties extends CommonVectorStoreProperties { } public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -54,7 +55,7 @@ public class AzureVectorStoreProperties extends CommonVectorStoreProperties { } public String getIndexName() { - return indexName; + return this.indexName; } public void setIndexName(String indexName) { @@ -62,7 +63,7 @@ public class AzureVectorStoreProperties extends CommonVectorStoreProperties { } public int getDefaultTopK() { - return defaultTopK; + return this.defaultTopK; } public void setDefaultTopK(int defaultTopK) { @@ -70,7 +71,7 @@ public class AzureVectorStoreProperties extends CommonVectorStoreProperties { } public double getDefaultSimilarityThreshold() { - return defaultSimilarityThreshold; + return this.defaultSimilarityThreshold; } public void setDefaultSimilarityThreshold(double defaultSimilarityThreshold) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java index 0431133b9..f500e1bdf 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -20,7 +20,6 @@ import java.time.Duration; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; - import io.micrometer.observation.ObservationRegistry; import org.springframework.ai.embedding.BatchingStrategy; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java index f84a2f0e4..18be88e5b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.cassandra; import com.google.api.client.util.Preconditions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaApiProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaApiProperties.java index b9d84ec00..4f651278a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaApiProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaApiProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.chroma; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -36,7 +37,7 @@ public class ChromaApiProperties { private String password; public String getHost() { - return host; + return this.host; } public void setHost(String baseUrl) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaConnectionDetails.java index 465086d34..58966d336 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.chroma; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java index 441cb04ad..0cf7aba5a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,9 @@ package org.springframework.ai.autoconfigure.vectorstore.chroma; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -31,10 +34,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; import org.springframework.web.client.RestClient; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.micrometer.observation.ObservationRegistry; - /** * @author Christian Tzolov * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreProperties.java index 1768ed81d..42d4edca0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreProperties.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.chroma; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; @@ -31,7 +32,7 @@ public class ChromaVectorStoreProperties extends CommonVectorStoreProperties { private String collectionName = ChromaVectorStore.DEFAULT_COLLECTION_NAME; public String getCollectionName() { - return collectionName; + return this.collectionName; } public void setCollectionName(String collectionName) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfiguration.java index 8fdbb1ad4..dd77e5341 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,7 +16,9 @@ package org.springframework.ai.autoconfigure.vectorstore.cosmosdb; +import com.azure.cosmos.CosmosAsyncClient; import com.azure.cosmos.CosmosClientBuilder; +import io.micrometer.observation.ObservationRegistry; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -30,8 +32,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import com.azure.cosmos.CosmosAsyncClient; -import io.micrometer.observation.ObservationRegistry; /** * @author Theo van Kraay diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreProperties.java index d7d06ac25..ac716cbd7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java index 3ca7a399b..3978f1449 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,7 @@ package org.springframework.ai.autoconfigure.vectorstore.elasticsearch; +import io.micrometer.observation.ObservationRegistry; import org.elasticsearch.client.RestClient; import org.springframework.ai.embedding.BatchingStrategy; @@ -33,8 +34,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; -import io.micrometer.observation.ObservationRegistry; - /** * @author Eddú Meléndez * @author Wei Jiang diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java index ba0100bb3..336a8f57e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.elasticsearch; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; @@ -52,7 +53,7 @@ public class ElasticsearchVectorStoreProperties extends CommonVectorStorePropert } public Integer getDimensions() { - return dimensions; + return this.dimensions; } public void setDimensions(Integer dimensions) { @@ -60,7 +61,7 @@ public class ElasticsearchVectorStoreProperties extends CommonVectorStorePropert } public SimilarityFunction getSimilarity() { - return similarity; + return this.similarity; } public void setSimilarity(SimilarityFunction similarity) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireConnectionDetails.java index 32a3a7486..2013f9f60 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.gemfire; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java index f9ed2f8ce..32e4fada4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java @@ -16,6 +16,8 @@ package org.springframework.ai.autoconfigure.vectorstore.gemfire; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -29,8 +31,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import io.micrometer.observation.ObservationRegistry; - /** * @author Geet Rawat * @author Christian Tzolov diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreProperties.java index 2e7e89cfe..5e650dc92 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreProperties.java @@ -94,7 +94,7 @@ public class GemFireVectorStoreProperties extends CommonVectorStoreProperties { private boolean sslEnabled = GemFireVectorStore.GemFireVectorStoreConfig.DEFAULT_SSL_ENABLED; public int getBeamWidth() { - return beamWidth; + return this.beamWidth; } public void setBeamWidth(int beamWidth) { @@ -102,7 +102,7 @@ public class GemFireVectorStoreProperties extends CommonVectorStoreProperties { } public int getPort() { - return port; + return this.port; } public void setPort(int port) { @@ -110,7 +110,7 @@ public class GemFireVectorStoreProperties extends CommonVectorStoreProperties { } public String getHost() { - return host; + return this.host; } public void setHost(String host) { @@ -118,7 +118,7 @@ public class GemFireVectorStoreProperties extends CommonVectorStoreProperties { } public String getIndexName() { - return indexName; + return this.indexName; } public void setIndexName(String indexName) { @@ -126,7 +126,7 @@ public class GemFireVectorStoreProperties extends CommonVectorStoreProperties { } public int getMaxConnections() { - return maxConnections; + return this.maxConnections; } public void setMaxConnections(int maxConnections) { @@ -134,7 +134,7 @@ public class GemFireVectorStoreProperties extends CommonVectorStoreProperties { } public String getVectorSimilarityFunction() { - return vectorSimilarityFunction; + return this.vectorSimilarityFunction; } public void setVectorSimilarityFunction(String vectorSimilarityFunction) { @@ -142,7 +142,7 @@ public class GemFireVectorStoreProperties extends CommonVectorStoreProperties { } public String[] getFields() { - return fields; + return this.fields; } public void setFields(String[] fields) { @@ -150,7 +150,7 @@ public class GemFireVectorStoreProperties extends CommonVectorStoreProperties { } public int getBuckets() { - return buckets; + return this.buckets; } public void setBuckets(int buckets) { @@ -158,7 +158,7 @@ public class GemFireVectorStoreProperties extends CommonVectorStoreProperties { } public boolean isSslEnabled() { - return sslEnabled; + return this.sslEnabled; } public void setSslEnabled(boolean sslEnabled) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfiguration.java index 21cb9b9aa..a2dda1e7a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.hanadb; import javax.sql.DataSource; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.vectorstore.HanaCloudVectorStore; import org.springframework.ai.vectorstore.HanaCloudVectorStoreConfig; @@ -31,8 +34,6 @@ import org.springframework.boot.autoconfigure.data.jpa.JpaRepositoriesAutoConfig import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import io.micrometer.observation.ObservationRegistry; - /** * @author Rahul Mittal * @author Christian Tzolov @@ -59,4 +60,4 @@ public class HanaCloudVectorStoreAutoConfiguration { customObservationConvention.getIfAvailable(() -> null)); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreProperties.java index 79dfbfbc1..4c30d2407 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.hanadb; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -31,7 +32,7 @@ public class HanaCloudVectorStoreProperties { private int topK; public String getTableName() { - return tableName; + return this.tableName; } public void setTableName(String tableName) { @@ -39,7 +40,7 @@ public class HanaCloudVectorStoreProperties { } public int getTopK() { - return topK; + return this.topK; } public void setTopK(int topK) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientConnectionDetails.java index b6d015630..ffc93daed 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.milvus; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientProperties.java index ecd3ebd51..9d677aa2f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusServiceClientProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.milvus; import java.util.concurrent.TimeUnit; @@ -29,6 +30,11 @@ public class MilvusServiceClientProperties { public static final String CONFIG_PREFIX = "spring.ai.vectorstore.milvus.client"; + /** + * Secure the authorization for this connection, set to True to enable TLS. + */ + protected boolean secure = false; + /** * Milvus host name/address. */ @@ -61,17 +67,17 @@ public class MilvusServiceClientProperties { */ private long keepAliveTimeMs = 55000; + /** + * Enables the keep-alive function for client channel. + */ + // private boolean keepAliveWithoutCalls = false; + /** * The keep-alive timeout value of client channel. The timeout value must be greater * than zero. */ private long keepAliveTimeoutMs = 20000; - /** - * Enables the keep-alive function for client channel. - */ - // private boolean keepAliveWithoutCalls = false; - /** * Deadline for how long you are willing to wait for a reply from the server. With a * deadline setting, the client will wait when encounter fast RPC fail caused by @@ -110,11 +116,6 @@ public class MilvusServiceClientProperties { */ private String serverName; - /** - * Secure the authorization for this connection, set to True to enable TLS. - */ - protected boolean secure = false; - /** * Idle timeout value of client channel. The timeout value must be larger than zero. */ @@ -131,7 +132,7 @@ public class MilvusServiceClientProperties { private String password = "milvus"; public String getHost() { - return host; + return this.host; } public void setHost(String host) { @@ -139,7 +140,7 @@ public class MilvusServiceClientProperties { } public int getPort() { - return port; + return this.port; } public void setPort(int port) { @@ -147,7 +148,7 @@ public class MilvusServiceClientProperties { } public String getUri() { - return uri; + return this.uri; } public void setUri(String uri) { @@ -155,7 +156,7 @@ public class MilvusServiceClientProperties { } public String getToken() { - return token; + return this.token; } public void setToken(String token) { @@ -163,7 +164,7 @@ public class MilvusServiceClientProperties { } public long getConnectTimeoutMs() { - return connectTimeoutMs; + return this.connectTimeoutMs; } public void setConnectTimeoutMs(long connectTimeoutMs) { @@ -171,7 +172,7 @@ public class MilvusServiceClientProperties { } public long getKeepAliveTimeMs() { - return keepAliveTimeMs; + return this.keepAliveTimeMs; } public void setKeepAliveTimeMs(long keepAliveTimeMs) { @@ -179,7 +180,7 @@ public class MilvusServiceClientProperties { } public long getKeepAliveTimeoutMs() { - return keepAliveTimeoutMs; + return this.keepAliveTimeoutMs; } public void setKeepAliveTimeoutMs(long keepAliveTimeoutMs) { @@ -195,7 +196,7 @@ public class MilvusServiceClientProperties { // } public long getRpcDeadlineMs() { - return rpcDeadlineMs; + return this.rpcDeadlineMs; } public void setRpcDeadlineMs(long rpcDeadlineMs) { @@ -203,7 +204,7 @@ public class MilvusServiceClientProperties { } public String getClientKeyPath() { - return clientKeyPath; + return this.clientKeyPath; } public void setClientKeyPath(String clientKeyPath) { @@ -211,7 +212,7 @@ public class MilvusServiceClientProperties { } public String getClientPemPath() { - return clientPemPath; + return this.clientPemPath; } public void setClientPemPath(String clientPemPath) { @@ -219,7 +220,7 @@ public class MilvusServiceClientProperties { } public String getCaPemPath() { - return caPemPath; + return this.caPemPath; } public void setCaPemPath(String caPemPath) { @@ -227,7 +228,7 @@ public class MilvusServiceClientProperties { } public String getServerPemPath() { - return serverPemPath; + return this.serverPemPath; } public void setServerPemPath(String serverPemPath) { @@ -235,7 +236,7 @@ public class MilvusServiceClientProperties { } public String getServerName() { - return serverName; + return this.serverName; } public void setServerName(String serverName) { @@ -243,7 +244,7 @@ public class MilvusServiceClientProperties { } public boolean isSecure() { - return secure; + return this.secure; } public void setSecure(boolean secure) { @@ -251,7 +252,7 @@ public class MilvusServiceClientProperties { } public long getIdleTimeoutMs() { - return idleTimeoutMs; + return this.idleTimeoutMs; } public void setIdleTimeoutMs(long idleTimeoutMs) { @@ -259,7 +260,7 @@ public class MilvusServiceClientProperties { } public String getUsername() { - return username; + return this.username; } public void setUsername(String username) { @@ -267,7 +268,7 @@ public class MilvusServiceClientProperties { } public String getPassword() { - return password; + return this.password; } public void setPassword(String password) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java index ec789bed8..e22d53b99 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.milvus; +import java.util.concurrent.TimeUnit; + import io.micrometer.observation.ObservationRegistry; import io.milvus.client.MilvusServiceClient; import io.milvus.param.ConnectParam; import io.milvus.param.IndexType; import io.milvus.param.MetricType; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -34,8 +38,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; -import java.util.concurrent.TimeUnit; - /** * @author Christian Tzolov * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java index 2a4b82864..9a17543b5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.milvus; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; @@ -58,6 +59,60 @@ public class MilvusVectorStoreProperties extends CommonVectorStoreProperties { */ private String indexParameters = "{\"nlist\":1024}"; + public String getDatabaseName() { + return this.databaseName; + } + + public void setDatabaseName(String databaseName) { + Assert.hasText(databaseName, "Database name should not be empty."); + this.databaseName = databaseName; + } + + public String getCollectionName() { + return this.collectionName; + } + + public void setCollectionName(String collectionName) { + Assert.hasText(collectionName, "Collection name should not be empty."); + this.collectionName = collectionName; + } + + public int getEmbeddingDimension() { + return this.embeddingDimension; + } + + public void setEmbeddingDimension(int embeddingDimension) { + Assert.isTrue(embeddingDimension > 0, "Embedding dimension should be a positive value."); + this.embeddingDimension = embeddingDimension; + } + + public MilvusIndexType getIndexType() { + return this.indexType; + } + + public void setIndexType(MilvusIndexType indexType) { + Assert.notNull(indexType, "Index type can not be null"); + this.indexType = indexType; + } + + public MilvusMetricType getMetricType() { + return this.metricType; + } + + public void setMetricType(MilvusMetricType metricType) { + Assert.notNull(metricType, "MetricType can not be null"); + this.metricType = metricType; + } + + public String getIndexParameters() { + return this.indexParameters; + } + + public void setIndexParameters(String indexParameters) { + Assert.notNull(indexParameters, "indexParameters can not be null"); + this.indexParameters = indexParameters; + } + public enum MilvusMetricType { /** @@ -94,58 +149,4 @@ public class MilvusVectorStoreProperties extends CommonVectorStoreProperties { } - public String getDatabaseName() { - return databaseName; - } - - public void setDatabaseName(String databaseName) { - Assert.hasText(databaseName, "Database name should not be empty."); - this.databaseName = databaseName; - } - - public String getCollectionName() { - return collectionName; - } - - public void setCollectionName(String collectionName) { - Assert.hasText(collectionName, "Collection name should not be empty."); - this.collectionName = collectionName; - } - - public int getEmbeddingDimension() { - return embeddingDimension; - } - - public void setEmbeddingDimension(int embeddingDimension) { - Assert.isTrue(embeddingDimension > 0, "Embedding dimension should be a positive value."); - this.embeddingDimension = embeddingDimension; - } - - public MilvusIndexType getIndexType() { - return indexType; - } - - public void setIndexType(MilvusIndexType indexType) { - Assert.notNull(indexType, "Index type can not be null"); - this.indexType = indexType; - } - - public MilvusMetricType getMetricType() { - return metricType; - } - - public void setMetricType(MilvusMetricType metricType) { - Assert.notNull(metricType, "MetricType can not be null"); - this.metricType = metricType; - } - - public String getIndexParameters() { - return indexParameters; - } - - public void setIndexParameters(String indexParameters) { - Assert.notNull(indexParameters, "indexParameters can not be null"); - this.indexParameters = indexParameters; - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java index f9053d0bb..59f5855a7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,10 @@ package org.springframework.ai.autoconfigure.vectorstore.mongo; +import java.util.Arrays; + +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -33,10 +37,6 @@ import org.springframework.data.mongodb.core.convert.MongoCustomConversions; import org.springframework.util.MimeType; import org.springframework.util.StringUtils; -import io.micrometer.observation.ObservationRegistry; - -import java.util.Arrays; - /** * @author Eddú Meléndez * @author Christian Tzolov @@ -86,6 +86,7 @@ public class MongoDBAtlasVectorStoreAutoConfiguration { @Bean public Converter mimeTypeToStringConverter() { return new Converter() { + @Override public String convert(MimeType source) { return source.toString(); @@ -96,6 +97,7 @@ public class MongoDBAtlasVectorStoreAutoConfiguration { @Bean public Converter stringToMimeTypeConverter() { return new Converter() { + @Override public MimeType convert(String source) { return MimeType.valueOf(source); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreProperties.java index 683337464..22b6c680e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.mongo; +import java.util.List; + import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; import org.springframework.boot.context.properties.ConfigurationProperties; -import java.util.List; - /** * @author Eddú Meléndez * @author Christian Tzolov diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java index 3faaa3b64..e5be310fe 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,7 @@ package org.springframework.ai.autoconfigure.vectorstore.neo4j; +import io.micrometer.observation.ObservationRegistry; import org.neo4j.driver.Driver; import org.springframework.ai.embedding.BatchingStrategy; @@ -31,8 +32,6 @@ import org.springframework.boot.autoconfigure.neo4j.Neo4jAutoConfiguration; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import io.micrometer.observation.ObservationRegistry; - /** * @author Jingzhou Ou * @author Josh Long diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreProperties.java index 5d782b8a9..53e75b46a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.neo4j; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfiguration.java index c37c22389..b69a48e1b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.observation; import io.micrometer.tracing.otel.bridge.OtelTracer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreQueryResponseObservationFilter; import org.springframework.ai.vectorstore.observation.VectorStoreQueryResponseObservationHandler; @@ -46,6 +48,11 @@ public class VectorStoreObservationAutoConfiguration { private static final Logger logger = LoggerFactory.getLogger(VectorStoreObservationAutoConfiguration.class); + private static void logQueryResponseContentWarning() { + logger.warn( + "You have enabled the inclusion of the query response content in the observations, with the risk of exposing sensitive or private information. Please, be careful!"); + } + /** * The query response content is typically too big to be included in an observation as * span attributes. That's why the preferred way to store it is as span events, which @@ -84,9 +91,4 @@ public class VectorStoreObservationAutoConfiguration { } - private static void logQueryResponseContentWarning() { - logger.warn( - "You have enabled the inclusion of the query response content in the observations, with the risk of exposing sensitive or private information. Please, be careful!"); - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationProperties.java index 33589abad..423ca5719 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.observation; import org.springframework.boot.context.properties.ConfigurationProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/package-info.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/package-info.java index 347dd6a3e..af2a6feec 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/package-info.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/observation/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchConnectionDetails.java index 39c3b4c34..1b6b3d207 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.opensearch; -import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; - import java.util.List; +import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; + public interface OpenSearchConnectionDetails extends ConnectionDetails { List getUris(); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java index fe19c2552..78fc694a7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,11 @@ package org.springframework.ai.autoconfigure.vectorstore.opensearch; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Optional; + +import io.micrometer.observation.ObservationRegistry; import org.apache.hc.client5.http.auth.AuthScope; import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; @@ -25,6 +30,11 @@ import org.opensearch.client.transport.OpenSearchTransport; import org.opensearch.client.transport.aws.AwsSdk2Transport; import org.opensearch.client.transport.aws.AwsSdk2TransportOptions; import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.apache.ApacheHttpClient; +import software.amazon.awssdk.regions.Region; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -40,17 +50,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import io.micrometer.observation.ObservationRegistry; -import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; -import software.amazon.awssdk.http.SdkHttpClient; -import software.amazon.awssdk.http.apache.ApacheHttpClient; -import software.amazon.awssdk.regions.Region; - -import java.net.URISyntaxException; -import java.util.List; -import java.util.Optional; - @AutoConfiguration @ConditionalOnClass({ OpenSearchVectorStore.class, EmbeddingModel.class, OpenSearchClient.class }) @EnableConfigurationProperties(OpenSearchVectorStoreProperties.class) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreProperties.java index a50c02ef6..a8b9f4e7e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.opensearch; +import java.util.List; + import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; import org.springframework.boot.context.properties.ConfigurationProperties; -import java.util.List; - @ConfigurationProperties(prefix = OpenSearchVectorStoreProperties.CONFIG_PREFIX) public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties { @@ -41,7 +42,7 @@ public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties private Aws aws = new Aws(); public List getUris() { - return uris; + return this.uris; } public void setUris(List uris) { @@ -57,7 +58,7 @@ public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties } public String getUsername() { - return username; + return this.username; } public void setUsername(String username) { @@ -65,7 +66,7 @@ public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties } public String getPassword() { - return password; + return this.password; } public void setPassword(String password) { @@ -73,7 +74,7 @@ public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties } public String getMappingJson() { - return mappingJson; + return this.mappingJson; } public void setMappingJson(String mappingJson) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java index 837e278c2..a63c95bac 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,6 +18,8 @@ package org.springframework.ai.autoconfigure.vectorstore.oracle; import javax.sql.DataSource; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -32,8 +34,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties import org.springframework.context.annotation.Bean; import org.springframework.jdbc.core.JdbcTemplate; -import io.micrometer.observation.ObservationRegistry; - /** * @author Loïc Lefèvre * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreProperties.java index 2d5eb2f7f..27fd396c4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.oracle; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; @@ -44,7 +45,7 @@ public class OracleVectorStoreProperties extends CommonVectorStoreProperties { private int searchAccuracy = DEFAULT_SEARCH_ACCURACY; public String getTableName() { - return tableName; + return this.tableName; } public void setTableName(String tableName) { @@ -52,7 +53,7 @@ public class OracleVectorStoreProperties extends CommonVectorStoreProperties { } public OracleVectorStore.OracleVectorStoreIndexType getIndexType() { - return indexType; + return this.indexType; } public void setIndexType(OracleVectorStore.OracleVectorStoreIndexType indexType) { @@ -60,7 +61,7 @@ public class OracleVectorStoreProperties extends CommonVectorStoreProperties { } public OracleVectorStore.OracleVectorStoreDistanceType getDistanceType() { - return distanceType; + return this.distanceType; } public void setDistanceType(OracleVectorStore.OracleVectorStoreDistanceType distanceType) { @@ -68,7 +69,7 @@ public class OracleVectorStoreProperties extends CommonVectorStoreProperties { } public int getDimensions() { - return dimensions; + return this.dimensions; } public void setDimensions(int dimensions) { @@ -76,7 +77,7 @@ public class OracleVectorStoreProperties extends CommonVectorStoreProperties { } public boolean isRemoveExistingVectorStoreTable() { - return removeExistingVectorStoreTable; + return this.removeExistingVectorStoreTable; } public void setRemoveExistingVectorStoreTable(boolean removeExistingVectorStoreTable) { @@ -84,7 +85,7 @@ public class OracleVectorStoreProperties extends CommonVectorStoreProperties { } public boolean isForcedNormalization() { - return forcedNormalization; + return this.forcedNormalization; } public void setForcedNormalization(boolean forcedNormalization) { @@ -92,7 +93,7 @@ public class OracleVectorStoreProperties extends CommonVectorStoreProperties { } public int getSearchAccuracy() { - return searchAccuracy; + return this.searchAccuracy; } public void setSearchAccuracy(int searchAccuracy) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java index ec4d76e07..8f9cf21b4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,6 +18,8 @@ package org.springframework.ai.autoconfigure.vectorstore.pgvector; import javax.sql.DataSource; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -32,8 +34,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties import org.springframework.context.annotation.Bean; import org.springframework.jdbc.core.JdbcTemplate; -import io.micrometer.observation.ObservationRegistry; - /** * @author Christian Tzolov * @author Josh Long diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java index 47a12c36d..d2947a5ad 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.pgvector; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; @@ -49,7 +50,7 @@ public class PgVectorStoreProperties extends CommonVectorStoreProperties { private int maxDocumentBatchSize = PgVectorStore.MAX_DOCUMENT_BATCH_SIZE; public int getDimensions() { - return dimensions; + return this.dimensions; } public void setDimensions(int dimensions) { @@ -57,7 +58,7 @@ public class PgVectorStoreProperties extends CommonVectorStoreProperties { } public PgIndexType getIndexType() { - return indexType; + return this.indexType; } public void setIndexType(PgIndexType createIndexMethod) { @@ -65,7 +66,7 @@ public class PgVectorStoreProperties extends CommonVectorStoreProperties { } public PgDistanceType getDistanceType() { - return distanceType; + return this.distanceType; } public void setDistanceType(PgDistanceType distanceType) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java index 058b62841..9526c3e3c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,8 @@ package org.springframework.ai.autoconfigure.vectorstore.pinecone; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -29,8 +31,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import io.micrometer.observation.ObservationRegistry; - /** * @author Christian Tzolov * @author Soby Chacko diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java index 3ba28c228..c73dfee07 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.pinecone; import java.time.Duration; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantConnectionDetails.java index e5cf97fb7..321d21b8e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.qdrant; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java index d1cb23790..3d1914200 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreProperties.java index 880f6925d..10438c5a9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.qdrant; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java index 92631831b..2ffaf86a2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,9 @@ package org.springframework.ai.autoconfigure.vectorstore.redis; +import io.micrometer.observation.ObservationRegistry; +import redis.clients.jedis.JedisPooled; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -32,9 +35,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties import org.springframework.context.annotation.Bean; import org.springframework.data.redis.connection.jedis.JedisConnectionFactory; -import io.micrometer.observation.ObservationRegistry; -import redis.clients.jedis.JedisPooled; - /** * @author Christian Tzolov * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreProperties.java index 4799afb80..9a1922607 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.redis; import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseConnectionDetails.java index 48d6b6cc3..94f0fd102 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.typesense; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseServiceClientProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseServiceClientProperties.java index 72e6f6a9f..bc4ab0dae 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseServiceClientProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseServiceClientProperties.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -39,7 +39,7 @@ public class TypesenseServiceClientProperties { private String apiKey = "xyz"; public String getProtocol() { - return protocol; + return this.protocol; } public void setProtocol(String protocol) { @@ -47,7 +47,7 @@ public class TypesenseServiceClientProperties { } public String getHost() { - return host; + return this.host; } public void setHost(String host) { @@ -55,7 +55,7 @@ public class TypesenseServiceClientProperties { } public int getPort() { - return port; + return this.port; } public void setPort(int port) { @@ -63,7 +63,7 @@ public class TypesenseServiceClientProperties { } public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java index de6e9a490..14789133d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,15 @@ package org.springframework.ai.autoconfigure.vectorstore.typesense; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; + +import io.micrometer.observation.ObservationRegistry; +import org.typesense.api.Client; +import org.typesense.api.Configuration; +import org.typesense.resources.Node; + import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -28,15 +37,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import org.typesense.api.Client; -import org.typesense.api.Configuration; -import org.typesense.resources.Node; - -import io.micrometer.observation.ObservationRegistry; - -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; /** * @author Pablo Sanchidrian Herrera diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreProperties.java index ddf74de4d..22eea6396 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreProperties.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -40,7 +40,7 @@ public class TypesenseVectorStoreProperties extends CommonVectorStoreProperties private int embeddingDimension = TypesenseVectorStore.OPENAI_EMBEDDING_DIMENSION_SIZE; public String getCollectionName() { - return collectionName; + return this.collectionName; } public void setCollectionName(String collectionName) { @@ -48,7 +48,7 @@ public class TypesenseVectorStoreProperties extends CommonVectorStoreProperties } public int getEmbeddingDimension() { - return embeddingDimension; + return this.embeddingDimension; } public void setEmbeddingDimension(int embeddingDimension) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateConnectionDetails.java index 154271c6f..5981040d3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateConnectionDetails.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateConnectionDetails.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.weaviate; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java index 9ca3899db..f16226bd7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreProperties.java index d0793de48..7b3a61f16 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.weaviate; import java.util.Map; -import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig; import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.ConsistentLevel; import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField; @@ -48,24 +48,24 @@ public class WeaviateVectorStoreProperties { private Map headers = Map.of(); + public String getScheme() { + return this.scheme; + } + public void setScheme(String scheme) { this.scheme = scheme; } - public String getScheme() { - return scheme; + public String getHost() { + return this.host; } public void setHost(String host) { this.host = host; } - public String getHost() { - return host; - } - public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -73,7 +73,7 @@ public class WeaviateVectorStoreProperties { } public String getObjectClass() { - return objectClass; + return this.objectClass; } public void setObjectClass(String indexName) { @@ -81,7 +81,7 @@ public class WeaviateVectorStoreProperties { } public ConsistentLevel getConsistencyLevel() { - return consistencyLevel; + return this.consistencyLevel; } public void setConsistencyLevel(ConsistentLevel consistencyLevel) { @@ -89,7 +89,7 @@ public class WeaviateVectorStoreProperties { } public Map getHeaders() { - return headers; + return this.headers; } public void setHeaders(Map headers) { @@ -97,7 +97,7 @@ public class WeaviateVectorStoreProperties { } public Map getFilterField() { - return filterField; + return this.filterField; } public void setFilterField(Map filterMetadataFields) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java index b51c0e718..7549cb727 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.embedding; import java.io.IOException; +import com.google.cloud.vertexai.VertexAI; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; @@ -34,10 +38,6 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import com.google.cloud.vertexai.VertexAI; - -import io.micrometer.observation.ObservationRegistry; - /** * Auto-configuration for Vertex AI Gemini Chat. * diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingConnectionProperties.java index 0073f5690..a86462d39 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.embedding; import org.springframework.boot.context.properties.ConfigurationProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiMultimodalEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiMultimodalEmbeddingProperties.java index 6d08403f5..a47488b92 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiMultimodalEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiMultimodalEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.embedding; import org.springframework.ai.vertexai.embedding.multimodal.VertexAiMultimodalEmbeddingOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingProperties.java index 102548521..26073283f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.embedding; import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java index a6c83100a..b73332e81 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.gemini; import java.io.IOException; import java.util.List; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.cloud.vertexai.VertexAI; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.model.function.FunctionCallback; @@ -38,11 +43,6 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import com.google.auth.oauth2.GoogleCredentials; -import com.google.cloud.vertexai.VertexAI; - -import io.micrometer.observation.ObservationRegistry; - /** * Auto-configuration for Vertex AI Gemini Chat. * diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiChatProperties.java index 4e9572c15..22d25ab24 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.gemini; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiConnectionProperties.java index ef65327b5..e47d41863 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.gemini; import java.util.List; @@ -58,18 +59,6 @@ public class VertexAiGeminiConnectionProperties { private Transport transport = Transport.GRPC; - public enum Transport { - - /** When used, the clients will send REST requests to the backing service. */ - REST, - /** - * When used, the clients will send gRPC to the backing service. This is usually - * more efficient and is the default transport. - */ - GRPC - - } - public String getProjectId() { return this.projectId; } @@ -98,6 +87,10 @@ public class VertexAiGeminiConnectionProperties { return this.apiEndpoint; } + public void setApiEndpoint(String apiEndpoint) { + this.apiEndpoint = apiEndpoint; + } + public List getScopes() { return this.scopes; } @@ -106,10 +99,6 @@ public class VertexAiGeminiConnectionProperties { this.scopes = scopes; } - public void setApiEndpoint(String apiEndpoint) { - this.apiEndpoint = apiEndpoint; - } - public Transport getTransport() { return this.transport; } @@ -118,4 +107,16 @@ public class VertexAiGeminiConnectionProperties { this.transport = transport; } + public enum Transport { + + /** When used, the clients will send REST requests to the backing service. */ + REST, + /** + * When used, the clients will send gRPC to the backing service. This is usually + * more efficient and is the default transport. + */ + GRPC + + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2AutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2AutoConfiguration.java index e2ac02ed6..e78250389 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2AutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2AutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.palm2; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2ConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2ConnectionProperties.java index 6c4ae48fb..49e93b8dc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2ConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2ConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.palm2; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2EmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2EmbeddingProperties.java index 0dc079b03..531c8d4fb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2EmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPalm2EmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.palm2; import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPlam2ChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPlam2ChatProperties.java index 417f9a89b..5627defd6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPlam2ChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/palm2/VertexAiPlam2ChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.palm2; import org.springframework.ai.vertexai.palm2.VertexAiPaLm2ChatOptions; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfiguration.java index 4d7f9f436..e5d53314b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.watsonxai; import org.springframework.ai.watsonx.WatsonxAiChatModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiChatProperties.java index 3f9dc8fe9..80b9abbf5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.watsonxai; +import java.util.List; + import org.springframework.ai.watsonx.WatsonxAiChatOptions; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; -import java.util.List; - /** * Chat properties for Watsonx.AI Chat. * diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiConnectionProperties.java index 0ffc3656d..5e4fecf21 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.watsonxai; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -42,7 +43,7 @@ public class WatsonxAiConnectionProperties { private String IAMToken; public String getBaseUrl() { - return baseUrl; + return this.baseUrl; } public void setBaseUrl(String baseUrl) { @@ -50,7 +51,7 @@ public class WatsonxAiConnectionProperties { } public String getStreamEndpoint() { - return streamEndpoint; + return this.streamEndpoint; } public void setStreamEndpoint(String streamEndpoint) { @@ -58,7 +59,7 @@ public class WatsonxAiConnectionProperties { } public String getTextEndpoint() { - return textEndpoint; + return this.textEndpoint; } public void setTextEndpoint(String textEndpoint) { @@ -66,7 +67,7 @@ public class WatsonxAiConnectionProperties { } public String getEmbeddingEndpoint() { - return embeddingEndpoint; + return this.embeddingEndpoint; } public void setEmbeddingEndpoint(String embeddingEndpoint) { @@ -74,7 +75,7 @@ public class WatsonxAiConnectionProperties { } public String getProjectId() { - return projectId; + return this.projectId; } public void setProjectId(String projectId) { @@ -82,7 +83,7 @@ public class WatsonxAiConnectionProperties { } public String getIAMToken() { - return IAMToken; + return this.IAMToken; } public void setIAMToken(String IAMToken) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiEmbeddingProperties.java index 42425a265..983291d21 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiEmbeddingProperties.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.autoconfigure.watsonxai; import org.springframework.ai.watsonx.WatsonxAiEmbeddingOptions; @@ -40,12 +56,12 @@ public class WatsonxAiEmbeddingProperties { return this.options; } - public void setEnabled(boolean enabled) { - this.enabled = enabled; - } - public boolean isEnabled() { return this.enabled; } + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java index 98afeaf79..7b89cad77 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; import java.util.List; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; @@ -42,8 +45,6 @@ import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import io.micrometer.observation.ObservationRegistry; - /** * @author Geng Rong */ diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiChatProperties.java index d1179e990..86004fec9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiChatProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; import org.springframework.ai.zhipuai.ZhiPuAiChatOptions; @@ -44,7 +45,7 @@ public class ZhiPuAiChatProperties extends ZhiPuAiParentProperties { .build(); public ZhiPuAiChatOptions getOptions() { - return options; + return this.options; } public void setOptions(ZhiPuAiChatOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiConnectionProperties.java index 6d850f3d7..798afd43a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiConnectionProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; import org.springframework.boot.context.properties.ConfigurationProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiEmbeddingProperties.java index 4e1c6ef80..86b8b0086 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiEmbeddingProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; import org.springframework.ai.document.MetadataMode; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiImageProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiImageProperties.java index 7463d4573..4751dbe6f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiImageProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiImageProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; import org.springframework.ai.zhipuai.ZhiPuAiImageOptions; @@ -39,7 +40,7 @@ public class ZhiPuAiImageProperties extends ZhiPuAiParentProperties { private ZhiPuAiImageOptions options = ZhiPuAiImageOptions.builder().build(); public ZhiPuAiImageOptions getOptions() { - return options; + return this.options; } public void setOptions(ZhiPuAiImageOptions options) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiParentProperties.java index 70d43d770..c89102ec1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiParentProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; /** @@ -25,7 +26,7 @@ class ZhiPuAiParentProperties { private String baseUrl; public String getApiKey() { - return apiKey; + return this.apiKey; } public void setApiKey(String apiKey) { @@ -33,7 +34,7 @@ class ZhiPuAiParentProperties { } public String getBaseUrl() { - return baseUrl; + return this.baseUrl; } public void setBaseUrl(String baseUrl) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index d5b849aa3..ef4dd4b51 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -1,3 +1,19 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration org.springframework.ai.autoconfigure.oci.genai.OCIGenAiAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfigurationIT.java index f35b324c9..d550d32b3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.anthropic; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.anthropic; import java.util.List; import java.util.stream.Collectors; @@ -24,6 +23,8 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.AnthropicChatOptions; import org.springframework.ai.anthropic.api.AnthropicApi; @@ -35,7 +36,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".*") public class AnthropicAutoConfigurationIT { @@ -48,7 +49,7 @@ public class AnthropicAutoConfigurationIT { @Test void call() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AnthropicChatModel chatModel = context.getBean(AnthropicChatModel.class); String response = chatModel.call("Hello"); assertThat(response).isNotEmpty(); @@ -58,7 +59,7 @@ public class AnthropicAutoConfigurationIT { @Test void callWith8KResponseContext() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.anthropic.beta-version=" + AnthropicApi.BETA_MAX_TOKENS, "spring.ai.anthropic.chat.options.model=" + AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getValue()) .run(context -> { @@ -72,7 +73,7 @@ public class AnthropicAutoConfigurationIT { @Test void stream() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AnthropicChatModel chatModel = context.getBean(AnthropicChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicPropertiesTests.java index ca9cca03f..e2909971d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/AnthropicPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.anthropic; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java index 3a0a80052..0c4e933c1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.anthropic.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.anthropic.tool; import java.util.List; import java.util.function.Function; @@ -24,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.AnthropicChatOptions; import org.springframework.ai.anthropic.api.AnthropicApi; @@ -40,6 +40,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; +import static org.assertj.core.api.Assertions.assertThat; + @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".*") class FunctionCallWithFunctionBeanIT { @@ -53,7 +55,7 @@ class FunctionCallWithFunctionBeanIT { @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues( "spring.ai.anthropic.chat.options.model=" + AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue()) .run(context -> { @@ -66,14 +68,14 @@ class FunctionCallWithFunctionBeanIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), AnthropicChatOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), AnthropicChatOptions.builder().withFunction("weatherFunction3").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -83,7 +85,7 @@ class FunctionCallWithFunctionBeanIT { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner + this.contextRunner .withPropertyValues( "spring.ai.anthropic.chat.options.model=" + AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue()) .run(context -> { @@ -96,7 +98,7 @@ class FunctionCallWithFunctionBeanIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), PortableFunctionCallingOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -121,4 +123,4 @@ class FunctionCallWithFunctionBeanIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java index 9f3cf79c3..9dccd4c50 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.anthropic.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.anthropic.tool; import java.util.List; @@ -23,6 +22,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.AnthropicChatOptions; import org.springframework.ai.anthropic.api.AnthropicApi; @@ -34,6 +34,8 @@ import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".*") public class FunctionCallWithPromptFunctionIT { @@ -45,7 +47,7 @@ public class FunctionCallWithPromptFunctionIT { @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues( "spring.ai.anthropic.chat.options.model=" + AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue()) .run(context -> { @@ -64,10 +66,10 @@ public class FunctionCallWithPromptFunctionIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/MockWeatherService.java index 752ddb2d7..e27e66300 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.anthropic.tool; import java.util.function.Function; @@ -30,14 +31,21 @@ import com.fasterxml.jackson.annotation.JsonPropertyDescription; */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -65,28 +73,23 @@ public class MockWeatherService implements Function { + this.contextRunner.run(context -> { AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @Test void httpRequestContainsUserAgentAndCustomHeaders() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.azure.openai.custom-headers.foo=bar", "spring.ai.azure.openai.custom-headers.fizz=buzz") .run(context -> { @@ -125,11 +131,11 @@ class AzureOpenAiAutoConfigurationIT { @Test void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); - Flux response = chatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = chatModel.stream(new Prompt(List.of(this.userMessage, this.systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(10); @@ -147,7 +153,7 @@ class AzureOpenAiAutoConfigurationIT { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AzureOpenAiEmbeddingModel embeddingModel = context.getBean(AzureOpenAiEmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingModel @@ -165,7 +171,7 @@ class AzureOpenAiAutoConfigurationIT { @Test @EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_TRANSCRIPTION_DEPLOYMENT_NAME", matches = ".+") void transcribe() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AzureOpenAiAudioTranscriptionModel transcriptionModel = context .getBean(AzureOpenAiAudioTranscriptionModel.class); Resource audioFile = new ClassPathResource("/speech/jfk.flac"); @@ -179,17 +185,17 @@ class AzureOpenAiAutoConfigurationIT { void chatActivation() { // Disable the chat auto-configuration. - contextRunner.withPropertyValues("spring.ai.azure.openai.chat.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.azure.openai.chat.enabled=false").run(context -> { assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isEmpty(); }); // The chat auto-configuration is enabled by default. - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isNotEmpty(); }); // Explicitly enable the chat auto-configuration. - contextRunner.withPropertyValues("spring.ai.azure.openai.chat.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.azure.openai.chat.enabled=true").run(context -> { assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isNotEmpty(); }); } @@ -198,17 +204,17 @@ class AzureOpenAiAutoConfigurationIT { void embeddingActivation() { // Disable the embedding auto-configuration. - contextRunner.withPropertyValues("spring.ai.azure.openai.embedding.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.azure.openai.embedding.enabled=false").run(context -> { assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isEmpty(); }); // The embedding auto-configuration is enabled by default. - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isNotEmpty(); }); // Explicitly enable the embedding auto-configuration. - contextRunner.withPropertyValues("spring.ai.azure.openai.embedding.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.azure.openai.embedding.enabled=true").run(context -> { assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isNotEmpty(); }); } @@ -217,19 +223,21 @@ class AzureOpenAiAutoConfigurationIT { void audioTranscriptionActivation() { // Disable the transcription auto-configuration. - contextRunner.withPropertyValues("spring.ai.azure.openai.audio.transcription.enabled=false").run(context -> { - assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isEmpty(); - }); + this.contextRunner.withPropertyValues("spring.ai.azure.openai.audio.transcription.enabled=false") + .run(context -> { + assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isEmpty(); + }); // The transcription auto-configuration is enabled by default. - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty(); }); // Explicitly enable the transcription auto-configuration. - contextRunner.withPropertyValues("spring.ai.azure.openai.audio.transcription.enabled=true").run(context -> { - assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty(); - }); + this.contextRunner.withPropertyValues("spring.ai.azure.openai.audio.transcription.enabled=true") + .run(context -> { + assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty(); + }); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java index 581f178c0..e83c75e23 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiDirectOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiDirectOpenAiAutoConfigurationIT.java index f042706aa..19c651b3e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiDirectOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiDirectOpenAiAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure; import java.util.List; @@ -21,19 +22,19 @@ import java.util.stream.Collectors; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.azure.openai.AzureOpenAiChatModel; -import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; +import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.embedding.EmbeddingResponse; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -73,20 +74,20 @@ public class AzureOpenAiDirectOpenAiAutoConfigurationIT { @Test public void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @Test public void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); - Flux response = chatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = chatModel.stream(new Prompt(List.of(this.userMessage, this.systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(10); @@ -104,7 +105,7 @@ public class AzureOpenAiDirectOpenAiAutoConfigurationIT { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { AzureOpenAiEmbeddingModel embeddingModel = context.getBean(AzureOpenAiEmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingModel diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java index dafd8f49a..fa2c77b1f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.autoconfigure.azure.tool; import org.springframework.util.StringUtils; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java index a7e04c351..dda06c824 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.azure.tool; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.azure.tool.DeploymentNameUtil.getDeploymentName; +package org.springframework.ai.autoconfigure.azure.tool; import java.util.List; import java.util.function.Function; @@ -25,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; @@ -39,6 +38,9 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.azure.tool.DeploymentNameUtil.getDeploymentName; + @EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+") @EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+") class FunctionCallWithFunctionBeanIT { @@ -55,7 +57,8 @@ class FunctionCallWithFunctionBeanIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.azure.openai.chat.options..deployment-name=" + getDeploymentName()) + this.contextRunner + .withPropertyValues("spring.ai.azure.openai.chat.options..deployment-name=" + getDeploymentName()) .run(context -> { ChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); @@ -66,14 +69,14 @@ class FunctionCallWithFunctionBeanIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), AzureOpenAiChatOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), AzureOpenAiChatOptions.builder().withFunction("weatherFunction3").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -82,7 +85,8 @@ class FunctionCallWithFunctionBeanIT { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner.withPropertyValues("spring.ai.azure.openai.chat.options..deployment-name=" + getDeploymentName()) + this.contextRunner + .withPropertyValues("spring.ai.azure.openai.chat.options..deployment-name=" + getDeploymentName()) .run(context -> { ChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); @@ -93,7 +97,7 @@ class FunctionCallWithFunctionBeanIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), PortableFunctionCallingOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -119,4 +123,4 @@ class FunctionCallWithFunctionBeanIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java index 9a54bef4e..62071ed41 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.tool; import java.util.List; @@ -25,8 +26,8 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackWrapper; @@ -54,7 +55,8 @@ public class FunctionCallWithFunctionWrapperIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.azure.openai.chat.options.deployment-name=" + getDeploymentName()) + this.contextRunner + .withPropertyValues("spring.ai.azure.openai.chat.options.deployment-name=" + getDeploymentName()) .run(context -> { AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); @@ -65,7 +67,7 @@ public class FunctionCallWithFunctionWrapperIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), AzureOpenAiChatOptions.builder().withFunction("WeatherInfo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("30", "10", "15"); @@ -86,4 +88,4 @@ public class FunctionCallWithFunctionWrapperIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java index 4c2b622ec..00a914535 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.tool; import java.util.List; @@ -25,8 +26,8 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -50,7 +51,8 @@ public class FunctionCallWithPromptFunctionIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.azure.openai.chat.options.deployment-name=" + getDeploymentName()) + this.contextRunner + .withPropertyValues("spring.ai.azure.openai.chat.options.deployment-name=" + getDeploymentName()) .run(context -> { AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); @@ -67,10 +69,10 @@ public class FunctionCallWithPromptFunctionIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java index 9333522be..0d390e57e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.tool; import java.util.function.Function; @@ -30,15 +31,21 @@ import com.fasterxml.jackson.annotation.JsonPropertyDescription; */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @Override + public Response apply(Request request) { - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -66,28 +73,24 @@ public class MockWeatherService implements Function { + this.contextRunner.run(context -> { BedrockAnthropicChatModel anthropicChatModel = context.getBean(BedrockAnthropicChatModel.class); - ChatResponse response = anthropicChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = anthropicChatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @Test public void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockAnthropicChatModel anthropicChatModel = context.getBean(BedrockAnthropicChatModel.class); - Flux response = anthropicChatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = anthropicChatModel + .stream(new Prompt(List.of(this.userMessage, this.systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(2); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java index 3defe79b3..8475517bd 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.anthropic3; import java.util.List; @@ -21,19 +22,19 @@ import java.util.stream.Collectors; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatModel; -import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatModel; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -68,20 +69,21 @@ public class BedrockAnthropic3ChatAutoConfigurationIT { @Test public void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockAnthropic3ChatModel anthropicChatModel = context.getBean(BedrockAnthropic3ChatModel.class); - ChatResponse response = anthropicChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = anthropicChatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @Test public void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockAnthropic3ChatModel anthropicChatModel = context.getBean(BedrockAnthropic3ChatModel.class); - Flux response = anthropicChatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = anthropicChatModel + .stream(new Prompt(List.of(this.userMessage, this.systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(2); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java index 83b487c90..bf7482915 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.cohere; import java.util.List; @@ -21,21 +22,21 @@ import java.util.stream.Collectors; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.bedrock.cohere.BedrockCohereChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.bedrock.cohere.BedrockCohereChatModel; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatModel; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.ReturnLikelihoods; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.Truncate; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -71,20 +72,21 @@ public class BedrockCohereChatAutoConfigurationIT { @Test public void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockCohereChatModel cohereChatModel = context.getBean(BedrockCohereChatModel.class); - ChatResponse response = cohereChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = cohereChatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @Test public void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockCohereChatModel cohereChatModel = context.getBean(BedrockCohereChatModel.class); - Flux response = cohereChatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = cohereChatModel + .stream(new Prompt(List.of(this.userMessage, this.systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(2); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java index 14d388955..4523a1955 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.cohere; +import java.util.List; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import software.amazon.awssdk.regions.Region; + import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.cohere.BedrockCohereEmbeddingModel; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingModel; @@ -25,9 +30,6 @@ import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.Coher import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import software.amazon.awssdk.regions.Region; - -import java.util.List; import static org.assertj.core.api.Assertions.assertThat; @@ -51,7 +53,7 @@ public class BedrockCohereEmbeddingAutoConfigurationIT { @Test public void singleEmbedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockCohereEmbeddingModel embeddingModel = context.getBean(BedrockCohereEmbeddingModel.class); assertThat(embeddingModel).isNotNull(); EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); @@ -63,7 +65,7 @@ public class BedrockCohereEmbeddingAutoConfigurationIT { @Test public void batchEmbedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockCohereEmbeddingModel embeddingModel = context.getBean(BedrockCohereEmbeddingModel.class); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/jurassic2/BedrockAi21Jurassic2ChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/jurassic2/BedrockAi21Jurassic2ChatAutoConfigurationIT.java index ace30a03d..2c2af6bda 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/jurassic2/BedrockAi21Jurassic2ChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/jurassic2/BedrockAi21Jurassic2ChatAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,24 +16,25 @@ package org.springframework.ai.autoconfigure.bedrock.jurassic2; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import software.amazon.awssdk.regions.Region; + import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.autoconfigure.bedrock.jurrasic2.BedrockAi21Jurassic2ChatAutoConfiguration; import org.springframework.ai.autoconfigure.bedrock.jurrasic2.BedrockAi21Jurassic2ChatProperties; import org.springframework.ai.bedrock.jurassic2.BedrockAi21Jurassic2ChatModel; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import software.amazon.awssdk.regions.Region; - -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -68,9 +69,10 @@ public class BedrockAi21Jurassic2ChatAutoConfigurationIT { @Test public void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockAi21Jurassic2ChatModel ai21Jurassic2ChatModel = context.getBean(BedrockAi21Jurassic2ChatModel.class); - ChatResponse response = ai21Jurassic2ChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = ai21Jurassic2ChatModel + .call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java index f1ed73b8b..6c4fecc11 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.llama; import java.util.List; @@ -21,19 +22,19 @@ import java.util.stream.Collectors; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.bedrock.llama.BedrockLlamaChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.bedrock.llama.BedrockLlamaChatModel; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -70,20 +71,21 @@ public class BedrockLlamaChatAutoConfigurationIT { @Test public void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockLlamaChatModel llamaChatModel = context.getBean(BedrockLlamaChatModel.class); - ChatResponse response = llamaChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = llamaChatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @Test public void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockLlamaChatModel llamaChatModel = context.getBean(BedrockLlamaChatModel.class); - Flux response = llamaChatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = llamaChatModel + .stream(new Prompt(List.of(this.userMessage, this.systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(2); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java index 94a2fda1b..78749d873 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.titan; import java.util.List; @@ -21,19 +22,19 @@ import java.util.stream.Collectors; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; import org.springframework.ai.bedrock.titan.BedrockTitanChatModel; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatModel; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -69,20 +70,20 @@ public class BedrockTitanChatAutoConfigurationIT { @Test public void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockTitanChatModel chatModel = context.getBean(BedrockTitanChatModel.class); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @Test public void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { BedrockTitanChatModel chatModel = context.getBean(BedrockTitanChatModel.class); - Flux response = chatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = chatModel.stream(new Prompt(List.of(this.userMessage, this.systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(1); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java index 5a5a2ad4c..525898ac0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.bedrock.titan; import java.util.Base64; @@ -51,7 +52,7 @@ public class BedrockTitanEmbeddingAutoConfigurationIT { @Test public void singleTextEmbedding() { - contextRunner.withPropertyValues("spring.ai.bedrock.titan.embedding.inputType=TEXT").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.bedrock.titan.embedding.inputType=TEXT").run(context -> { BedrockTitanEmbeddingModel embeddingModel = context.getBean(BedrockTitanEmbeddingModel.class); assertThat(embeddingModel).isNotNull(); EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); @@ -63,7 +64,7 @@ public class BedrockTitanEmbeddingAutoConfigurationIT { @Test public void singleImageEmbedding() { - contextRunner.withPropertyValues("spring.ai.bedrock.titan.embedding.inputType=IMAGE").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.bedrock.titan.embedding.inputType=IMAGE").run(context -> { BedrockTitanEmbeddingModel embeddingModel = context.getBean(BedrockTitanEmbeddingModel.class); assertThat(embeddingModel).isNotNull(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfigurationIT.java index 159e2cb1d..4dd65c789 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfigurationIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.client; import java.util.List; @@ -50,28 +51,28 @@ public class ChatClientAutoConfigurationIT { @Test void implicitlyEnabled() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(ChatClient.Builder.class)).isNotEmpty(); }); } @Test void explicitlyEnabled() { - contextRunner.withPropertyValues("spring.ai.chat.client.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.chat.client.enabled=true").run(context -> { assertThat(context.getBeansOfType(ChatClient.Builder.class)).isNotEmpty(); }); } @Test void explicitlyDisabled() { - contextRunner.withPropertyValues("spring.ai.chat.client.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.chat.client.enabled=false").run(context -> { assertThat(context.getBeansOfType(ChatClient.Builder.class)).isEmpty(); }); } @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { ChatClient.Builder builder = context.getBean(ChatClient.Builder.class); assertThat(builder).isNotNull(); @@ -87,7 +88,7 @@ public class ChatClientAutoConfigurationIT { @Test void testChatClientCustomizers() { - contextRunner.withUserConfiguration(Config.class).run(context -> { + this.contextRunner.withUserConfiguration(Config.class).run(context -> { ChatClient.Builder builder = context.getBean(ChatClient.Builder.class); @@ -107,6 +108,7 @@ public class ChatClientAutoConfigurationIT { } record ActorsFilms(String actor, List movies) { + } @Configuration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientObservationAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientObservationAutoConfigurationTests.java index 94a7ec0db..658a758b0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientObservationAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientObservationAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.client; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.client.observation.ChatClientInputContentObservationFilter; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + /** * Unit tests for {@link ChatClientAutoConfiguration} observability support. * @@ -34,14 +36,14 @@ class ChatClientObservationAutoConfigurationTests { @Test void inputContentFilterDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(ChatClientInputContentObservationFilter.class); }); } @Test void inputContentFilterEnabled() { - contextRunner.withPropertyValues("spring.ai.chat.client.observations.include-input=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.chat.client.observations.include-input=true").run(context -> { assertThat(context).hasSingleBean(ChatClientInputContentObservationFilter.class); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfigurationIT.java index 86df7f40c..abc4b6c43 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.chat.memory.cassandra; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.chat.memory.cassandra; import java.util.List; +import com.datastax.driver.core.utils.UUIDs; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.CassandraContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.chat.memory.CassandraChatMemory; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; @@ -27,12 +32,8 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import org.testcontainers.containers.CassandraContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.DockerImageName; -import com.datastax.driver.core.utils.UUIDs; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Mick Semb Wever @@ -53,7 +54,7 @@ class CassandraChatMemoryAutoConfigurationIT { @Test void addAndGet() { - contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost()) + this.contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost()) .withPropertyValues("spring.cassandra.port=" + getContactPointPort()) .withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter()) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryPropertiesTest.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryPropertiesTest.java index e3df66722..c3b473089 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryPropertiesTest.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/cassandra/CassandraChatMemoryPropertiesTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.memory.cassandra; import java.time.Duration; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfigurationTests.java index 7e4f01a69..cafd64873 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.observation; import io.micrometer.core.instrument.composite.CompositeMeterRegistry; @@ -20,7 +21,12 @@ import io.micrometer.tracing.otel.bridge.OtelCurrentTraceContext; import io.micrometer.tracing.otel.bridge.OtelTracer; import io.opentelemetry.api.OpenTelemetry; import org.junit.jupiter.api.Test; -import org.springframework.ai.chat.observation.*; + +import org.springframework.ai.chat.observation.ChatModelCompletionObservationFilter; +import org.springframework.ai.chat.observation.ChatModelCompletionObservationHandler; +import org.springframework.ai.chat.observation.ChatModelMeterObservationHandler; +import org.springframework.ai.chat.observation.ChatModelPromptContentObservationFilter; +import org.springframework.ai.chat.observation.ChatModelPromptContentObservationHandler; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -38,35 +44,35 @@ class ChatObservationAutoConfigurationTests { @Test void meterObservationHandlerEnabled() { - contextRunner.withBean(CompositeMeterRegistry.class).run(context -> { + this.contextRunner.withBean(CompositeMeterRegistry.class).run(context -> { assertThat(context).hasSingleBean(ChatModelMeterObservationHandler.class); }); } @Test void meterObservationHandlerDisabled() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(ChatModelMeterObservationHandler.class); }); } @Test void promptFilterDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(ChatModelPromptContentObservationFilter.class); }); } @Test void promptHandlerDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(ChatModelPromptContentObservationHandler.class); }); } @Test void promptHandlerEnabled() { - contextRunner + this.contextRunner .withBean(OtelTracer.class, OpenTelemetry.noop().getTracer("test"), new OtelCurrentTraceContext(), null) .withPropertyValues("spring.ai.chat.observations.include-prompt=true") .run(context -> { @@ -76,28 +82,28 @@ class ChatObservationAutoConfigurationTests { @Test void promptHandlerDisabled() { - contextRunner.withPropertyValues("spring.ai.chat.observations.include-prompt=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.chat.observations.include-prompt=true").run(context -> { assertThat(context).doesNotHaveBean(ChatModelPromptContentObservationHandler.class); }); } @Test void completionFilterDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(ChatModelCompletionObservationFilter.class); }); } @Test void completionHandlerDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(ChatModelCompletionObservationHandler.class); }); } @Test void completionHandlerEnabled() { - contextRunner + this.contextRunner .withBean(OtelTracer.class, OpenTelemetry.noop().getTracer("test"), new OtelCurrentTraceContext(), null) .withPropertyValues("spring.ai.chat.observations.include-completion=true") .run(context -> { @@ -107,7 +113,7 @@ class ChatObservationAutoConfigurationTests { @Test void completionHandlerDisabled() { - contextRunner.withPropertyValues("spring.ai.chat.observations.include-completion=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.chat.observations.include-completion=true").run(context -> { assertThat(context).doesNotHaveBean(ChatModelCompletionObservationHandler.class); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfigurationTests.java index c479690f6..ad1910337 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.embedding.observation; import io.micrometer.core.instrument.composite.CompositeMeterRegistry; import org.junit.jupiter.api.Test; + import org.springframework.ai.embedding.observation.EmbeddingModelMeterObservationHandler; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -35,14 +37,14 @@ class EmbeddingObservationAutoConfigurationTests { @Test void meterObservationHandlerEnabled() { - contextRunner.withBean(CompositeMeterRegistry.class).run(context -> { + this.contextRunner.withBean(CompositeMeterRegistry.class).run(context -> { assertThat(context).hasSingleBean(EmbeddingModelMeterObservationHandler.class); }); } @Test void meterObservationHandlerDisabled() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(EmbeddingModelMeterObservationHandler.class); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfigurationIT.java index 300962345..a0b5c014d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/huggingface/HuggingfaceChatAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.huggingface; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.huggingface; import java.util.List; import java.util.stream.Collectors; @@ -25,6 +24,8 @@ import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -34,7 +35,7 @@ import org.springframework.ai.huggingface.HuggingfaceChatModel; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "HUGGINGFACE_API_KEY", matches = ".+") @EnabledIfEnvironmentVariable(named = "HUGGINGFACE_CHAT_URL", matches = ".+") @@ -43,7 +44,7 @@ public class HuggingfaceChatAutoConfigurationIT { private static final Log logger = LogFactory.getLog(HuggingfaceChatAutoConfigurationIT.class); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( - // @formatter:off + // @formatter:off "spring.ai.huggingface.chat.api-key=" + System.getenv("HUGGINGFACE_API_KEY"), "spring.ai.huggingface.chat.url=" + System.getenv("HUGGINGFACE_CHAT_URL")) // @formatter:on @@ -51,7 +52,7 @@ public class HuggingfaceChatAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { HuggingfaceChatModel chatModel = context.getBean(HuggingfaceChatModel.class); String response = chatModel.call("Hello"); assertThat(response).isNotEmpty(); @@ -62,7 +63,7 @@ public class HuggingfaceChatAutoConfigurationIT { @Disabled("Until streaming support is added") @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { HuggingfaceChatModel chatModel = context.getBean(HuggingfaceChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfigurationTests.java index 0c26b992a..b4bd23203 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.image.observation; import org.junit.jupiter.api.Test; + import org.springframework.ai.image.observation.ImageModelPromptContentObservationFilter; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -34,14 +36,14 @@ class ImageObservationAutoConfigurationTests { @Test void promptFilterDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(ImageModelPromptContentObservationFilter.class); }); } @Test void promptFilterEnabled() { - contextRunner.withPropertyValues("spring.ai.image.observations.include-prompt=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.image.observations.include-prompt=true").run(context -> { assertThat(context).hasSingleBean(ImageModelPromptContentObservationFilter.class); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackInPromptIT.java index 6d97b82de..167102cb8 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackInPromptIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -31,10 +37,6 @@ import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -53,7 +55,7 @@ public class FunctionCallbackInPromptIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); @@ -70,7 +72,7 @@ public class FunctionCallbackInPromptIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -79,7 +81,7 @@ public class FunctionCallbackInPromptIT { @Test void streamingFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); @@ -104,7 +106,7 @@ public class FunctionCallbackInPromptIT { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -112,4 +114,4 @@ public class FunctionCallbackInPromptIT { }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java index ceb00d476..5b33d4c67 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -35,11 +42,6 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.function.Function; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -60,7 +62,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { // FIXME: multiple function calls may stop prematurely due to model performance @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); @@ -71,7 +73,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), MiniMaxChatOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -79,7 +81,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { response = chatModel.call(new Prompt(List.of(userMessage), MiniMaxChatOptions.builder().withFunction("weatherFunctionTwo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -88,7 +90,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); @@ -102,14 +104,14 @@ class FunctionCallbackWithPlainFunctionBeanIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); }); } // FIXME: multiple function calls may stop prematurely due to model performance @Test void streamFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); @@ -128,7 +130,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -146,7 +148,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -173,4 +175,4 @@ class FunctionCallbackWithPlainFunctionBeanIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWrapperIT.java index 612622d69..780476ea2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWrapperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -34,10 +40,6 @@ import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfigura import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -57,7 +59,7 @@ public class FunctionCallbackWrapperIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); @@ -67,7 +69,7 @@ public class FunctionCallbackWrapperIT { ChatResponse response = chatModel.call( new Prompt(List.of(userMessage), MiniMaxChatOptions.builder().withFunction("WeatherInfo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -76,7 +78,7 @@ public class FunctionCallbackWrapperIT { @Test void streamFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.minimax.chat.options.model=abab6.5s-chat").run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); @@ -94,7 +96,7 @@ public class FunctionCallbackWrapperIT { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -118,4 +120,4 @@ public class FunctionCallbackWrapperIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfigurationIT.java index 507ed41ff..06e8c2b3e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; +import java.util.List; +import java.util.stream.Collectors; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -29,10 +35,6 @@ import org.springframework.ai.minimax.MiniMaxEmbeddingModel; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -51,7 +53,7 @@ public class MiniMaxAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); String response = chatModel.call("Hello"); assertThat(response).isNotEmpty(); @@ -61,7 +63,7 @@ public class MiniMaxAutoConfigurationIT { @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); String response = responseFlux.collectList().block().stream().map(chatResponse -> { @@ -75,7 +77,7 @@ public class MiniMaxAutoConfigurationIT { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MiniMaxEmbeddingModel embeddingModel = context.getBean(MiniMaxEmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingModel diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxPropertiesTests.java index f8a2f5e2a..47131bedc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; import org.junit.jupiter.api.Test; import org.skyscreamer.jsonassert.JSONAssert; import org.skyscreamer.jsonassert.JSONCompareMode; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.minimax.MiniMaxChatModel; import org.springframework.ai.minimax.MiniMaxEmbeddingModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MockWeatherService.java index b7f792d3a..61a5394db 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.minimax; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * Mock 3rd party weather service. * @@ -30,16 +31,21 @@ import java.util.function.Function; */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Get the weather in location") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -67,28 +73,25 @@ public class MockWeatherService implements Function { + this.contextRunner.run(context -> { MistralAiChatModel chatModel = context.getBean(MistralAiChatModel.class); String response = chatModel.call("Hello"); assertThat(response).isNotEmpty(); @@ -60,7 +61,7 @@ public class MistralAiAutoConfigurationIT { @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MistralAiChatModel chatModel = context.getBean(MistralAiChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); String response = responseFlux.collectList().block().stream().map(chatResponse -> { @@ -74,7 +75,7 @@ public class MistralAiAutoConfigurationIT { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MistralAiEmbeddingModel embeddingModel = context.getBean(MistralAiEmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingModel diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java index ee4cf9aa7..e69a711ff 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.mistralai; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java index 69709a910..caa18b90b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.mistralai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.mistralai.tool; import java.util.List; import java.util.Map; import java.util.function.Function; +import com.fasterxml.jackson.annotation.JsonProperty; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -38,11 +39,16 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; -import com.fasterxml.jackson.annotation.JsonProperty; +import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".*") class PaymentStatusBeanIT { + // Assuming we have the following data + public static final Map DATA = Map.of("T1001", new StatusDate("Paid", "2021-10-05"), "T1002", + new StatusDate("Unpaid", "2021-10-06"), "T1003", new StatusDate("Paid", "2021-10-07"), "T1004", + new StatusDate("Paid", "2021-10-05"), "T1005", new StatusDate("Pending", "2021-10-08")); + private final Logger logger = LoggerFactory.getLogger(PaymentStatusBeanIT.class); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() @@ -53,7 +59,7 @@ class PaymentStatusBeanIT { @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.LARGE.getValue()) .run(context -> { @@ -66,33 +72,20 @@ class PaymentStatusBeanIT { .withFunction("retrievePaymentDate") .build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("T1001"); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("paid"); }); } - // Assuming we have the following data - public static final Map DATA = Map.of("T1001", new StatusDate("Paid", "2021-10-05"), "T1002", - new StatusDate("Unpaid", "2021-10-06"), "T1003", new StatusDate("Paid", "2021-10-07"), "T1004", - new StatusDate("Paid", "2021-10-05"), "T1005", new StatusDate("Pending", "2021-10-08")); - record StatusDate(String status, String date) { + } @Configuration static class Config { - public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) { - } - - public record Status(@JsonProperty(required = true, value = "status") String status) { - } - - public record Date(@JsonProperty(required = true, value = "date") String date) { - } - @Bean @Description("Get payment status of a transaction") public Function retrievePaymentStatus() { @@ -105,6 +98,18 @@ class PaymentStatusBeanIT { return (transaction) -> new Date(DATA.get(transaction.transactionId).date()); } + public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) { + + } + + public record Status(@JsonProperty(required = true, value = "status") String status) { + + } + + public record Date(@JsonProperty(required = true, value = "date") String date) { + + } + } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java index 3fc46b03b..5a428b91b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.mistralai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.mistralai.tool; import java.util.List; import java.util.Map; import java.util.function.Function; +import com.fasterxml.jackson.annotation.JsonProperty; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -38,7 +39,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; -import com.fasterxml.jackson.annotation.JsonProperty; +import static org.assertj.core.api.Assertions.assertThat; /** * Same test as {@link PaymentStatusBeanIT.java} but using {@link OpenAiChatModel} for @@ -49,6 +50,11 @@ import com.fasterxml.jackson.annotation.JsonProperty; @EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".*") class PaymentStatusBeanOpenAiIT { + // Assuming we have the following data + public static final Map DATA = Map.of("T1001", new StatusDate("Paid", "2021-10-05"), "T1002", + new StatusDate("Unpaid", "2021-10-06"), "T1003", new StatusDate("Paid", "2021-10-07"), "T1004", + new StatusDate("Paid", "2021-10-05"), "T1005", new StatusDate("Pending", "2021-10-08")); + private final Logger logger = LoggerFactory.getLogger(PaymentStatusBeanIT.class); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() @@ -60,7 +66,7 @@ class PaymentStatusBeanOpenAiIT { @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.openai.chat.options.model=" + MistralAiApi.ChatModel.SMALL.getValue()) .run(context -> { @@ -73,33 +79,20 @@ class PaymentStatusBeanOpenAiIT { .withFunction("retrievePaymentDate") .build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("T1001"); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("paid"); }); } - // Assuming we have the following data - public static final Map DATA = Map.of("T1001", new StatusDate("Paid", "2021-10-05"), "T1002", - new StatusDate("Unpaid", "2021-10-06"), "T1003", new StatusDate("Paid", "2021-10-07"), "T1004", - new StatusDate("Paid", "2021-10-05"), "T1005", new StatusDate("Pending", "2021-10-08")); - record StatusDate(String status, String date) { + } @Configuration static class Config { - public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) { - } - - public record Status(@JsonProperty(required = true, value = "status") String status) { - } - - public record Date(@JsonProperty(required = true, value = "date") String date) { - } - @Bean @Description("Get payment status of a transaction") public Function retrievePaymentStatus() { @@ -112,6 +105,18 @@ class PaymentStatusBeanOpenAiIT { return (transaction) -> new Date(DATA.get(transaction.transactionId).date()); } + public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) { + + } + + public record Status(@JsonProperty(required = true, value = "status") String status) { + + } + + public record Date(@JsonProperty(required = true, value = "date") String date) { + + } + } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java index 0cf0d18b0..2efdad493 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.mistralai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.mistralai.tool; import java.util.List; import java.util.Map; import java.util.function.Function; +import com.fasterxml.jackson.annotation.JsonProperty; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -36,35 +37,26 @@ import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import com.fasterxml.jackson.annotation.JsonProperty; +import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".*") public class PaymentStatusPromptIT { - private final Logger logger = LoggerFactory.getLogger(WeatherServicePromptIT.class); - - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withPropertyValues("spring.ai.mistralai.apiKey=" + System.getenv("MISTRAL_AI_API_KEY")) - .withConfiguration(AutoConfigurations.of(MistralAiAutoConfiguration.class)); - - public record Transaction(@JsonProperty(required = true, value = "transaction_id") String id) { - } - - public record Status(@JsonProperty(required = true, value = "status") String status) { - } - - record StatusDate(String status, String date) { - } - // Assuming we have the following payment data. public static final Map DATA = Map.of(new Transaction("T1001"), new StatusDate("Paid", "2021-10-05"), new Transaction("T1002"), new StatusDate("Unpaid", "2021-10-06"), new Transaction("T1003"), new StatusDate("Paid", "2021-10-07"), new Transaction("T1004"), new StatusDate("Paid", "2021-10-05"), new Transaction("T1005"), new StatusDate("Pending", "2021-10-08")); + private final Logger logger = LoggerFactory.getLogger(WeatherServicePromptIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.mistralai.apiKey=" + System.getenv("MISTRAL_AI_API_KEY")) + .withConfiguration(AutoConfigurations.of(MistralAiAutoConfiguration.class)); + @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.SMALL.getValue()) .run(context -> { @@ -74,6 +66,7 @@ public class PaymentStatusPromptIT { var promptOptions = MistralAiChatOptions.builder() .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new Function() { + public Status apply(Transaction transaction) { return new Status(DATA.get(transaction).status()); } @@ -85,11 +78,23 @@ public class PaymentStatusPromptIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("T1001"); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("paid"); }); } -} \ No newline at end of file + public record Transaction(@JsonProperty(required = true, value = "transaction_id") String id) { + + } + + public record Status(@JsonProperty(required = true, value = "status") String status) { + + } + + record StatusDate(String status, String date) { + + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java index 13750c3b0..f546eb8bb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.mistralai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.mistralai.tool; import java.util.List; import java.util.function.Function; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration; import org.springframework.ai.autoconfigure.mistralai.tool.WeatherServicePromptIT.MyWeatherService.Request; import org.springframework.ai.autoconfigure.mistralai.tool.WeatherServicePromptIT.MyWeatherService.Response; @@ -40,9 +43,7 @@ import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.Porta import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -59,7 +60,7 @@ public class WeatherServicePromptIT { @Test void promptFunctionCall() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.LARGE.getValue()) .run(context -> { @@ -80,7 +81,7 @@ public class WeatherServicePromptIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15", "15.0"); // assertThat(response.getResult().getOutput().getContent()).contains("30.0", @@ -90,7 +91,7 @@ public class WeatherServicePromptIT { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.mistralai.chat.options.model=" + MistralAiApi.ChatModel.LARGE.getValue()) .run(context -> { @@ -108,7 +109,7 @@ public class WeatherServicePromptIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15", "15.0"); }); @@ -116,17 +117,6 @@ public class WeatherServicePromptIT { public static class MyWeatherService implements Function { - // @formatter:off - public enum Unit { C, F } - - @JsonInclude(Include.NON_NULL) - public record Request( - @JsonProperty(required = true, value = "location") String location, - @JsonProperty(required = true, value = "unit") Unit unit) {} - - public record Response(double temperature, Unit unit) {} - // @formatter:on - @Override public Response apply(Request request) { if (request.location().contains("Paris")) { @@ -141,6 +131,19 @@ public class WeatherServicePromptIT { throw new IllegalArgumentException("Invalid request: " + request); } + // @formatter:off + public enum Unit { C, F } + + @JsonInclude(Include.NON_NULL) + public record Request( + @JsonProperty(required = true, value = "location") String location, + @JsonProperty(required = true, value = "unit") Unit unit) {} + // @formatter:on + + public record Response(double temperature, Unit unit) { + + } + } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfigurationIT.java index 196e04f82..f24c6599c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot; +import java.util.Objects; +import java.util.stream.Collectors; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -27,10 +33,6 @@ import org.springframework.ai.moonshot.MoonshotChatModel; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; - -import java.util.Objects; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -49,7 +51,7 @@ public class MoonshotAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel client = context.getBean(MoonshotChatModel.class); String response = client.call("Hello"); assertThat(response).isNotEmpty(); @@ -59,7 +61,7 @@ public class MoonshotAutoConfigurationIT { @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel client = context.getBean(MoonshotChatModel.class); Flux responseFlux = client.stream(new Prompt(new UserMessage("Hello"))); String response = Objects.requireNonNull(responseFlux.collectList().block()) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotPropertiesTests.java index 213ccd945..79e7bd142 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/MoonshotPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot; import org.junit.jupiter.api.Test; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.moonshot.MoonshotChatModel; import org.springframework.boot.autoconfigure.AutoConfigurations; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackInPromptIT.java index b31f9c28f..2853c4a46 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackInPromptIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot.tool; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.moonshot.MoonshotAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -32,10 +38,6 @@ import org.springframework.ai.moonshot.MoonshotChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -54,7 +56,7 @@ public class FunctionCallbackInPromptIT { @Test void functionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel chatModel = context.getBean(MoonshotChatModel.class); @@ -71,7 +73,7 @@ public class FunctionCallbackInPromptIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -80,7 +82,7 @@ public class FunctionCallbackInPromptIT { @Test void streamingFunctionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel chatModel = context.getBean(MoonshotChatModel.class); @@ -105,7 +107,7 @@ public class FunctionCallbackInPromptIT { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -113,4 +115,4 @@ public class FunctionCallbackInPromptIT { }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java index e21c283c7..e94be4220 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot.tool; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.moonshot.MoonshotAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -36,11 +43,6 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.function.Function; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -60,7 +62,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { @Test void functionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel chatModel = context.getBean(MoonshotChatModel.class); @@ -71,7 +73,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), MoonshotChatOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -79,7 +81,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { response = chatModel.call(new Prompt(List.of(userMessage), MoonshotChatOptions.builder().withFunction("weatherFunctionTwo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -88,7 +90,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel chatModel = context.getBean(MoonshotChatModel.class); @@ -102,13 +104,13 @@ class FunctionCallbackWithPlainFunctionBeanIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); }); } @Test void streamFunctionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel chatModel = context.getBean(MoonshotChatModel.class); @@ -127,7 +129,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -145,7 +147,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -172,4 +174,4 @@ class FunctionCallbackWithPlainFunctionBeanIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWrapperIT.java index a5b7c8779..9de829cc7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWrapperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot.tool; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.moonshot.MoonshotAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -35,11 +42,6 @@ import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfigura import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -59,7 +61,7 @@ public class FunctionCallbackWrapperIT { @Test void functionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel chatModel = context.getBean(MoonshotChatModel.class); @@ -69,7 +71,7 @@ public class FunctionCallbackWrapperIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), MoonshotChatOptions.builder().withFunction("WeatherInfo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -78,7 +80,7 @@ public class FunctionCallbackWrapperIT { @Test void streamFunctionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MoonshotChatModel chatModel = context.getBean(MoonshotChatModel.class); @@ -97,7 +99,7 @@ public class FunctionCallbackWrapperIT { .map(AssistantMessage::getContent) .filter(Objects::nonNull) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -121,4 +123,4 @@ public class FunctionCallbackWrapperIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/MockWeatherService.java index 0bdfcbb1c..3d8e96ba6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.moonshot.tool; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * Mock 3rd party weather service. * @@ -30,14 +31,21 @@ import java.util.function.Function; */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -65,28 +73,23 @@ public class MockWeatherService implements Function { + this.contextRunner.run(context -> { OCIEmbeddingModel embeddingModel = context.getBean(OCIEmbeddingModel.class); assertThat(embeddingModel).isNotNull(); EmbeddingResponse response = embeddingModel diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java index f43234039..3a096dc09 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java @@ -1,21 +1,33 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.autoconfigure.ollama; +import org.testcontainers.ollama.OllamaContainer; + import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; -import org.testcontainers.ollama.OllamaContainer; public class BaseOllamaIT { - // Toggle for running tests locally on native Ollama for a faster feedback loop. - private static final boolean useTestcontainers = true; - public static final OllamaContainer ollamaContainer; - static { - ollamaContainer = new OllamaContainer(OllamaImage.IMAGE).withReuse(true); - ollamaContainer.start(); - } + // Toggle for running tests locally on native Ollama for a faster feedback loop. + private static final boolean useTestcontainers = true; /** * Change the return value to false in order to run multiple Ollama IT tests locally @@ -41,4 +53,9 @@ public class BaseOllamaIT { return baseUrl; } + static { + ollamaContainer = new OllamaContainer(OllamaImage.IMAGE).withReuse(true); + ollamaContainer.start(); + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java index 10b07fb45..b596b14f0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.ollama; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.ollama; import java.io.IOException; import java.util.List; @@ -24,6 +23,9 @@ import java.util.stream.Collectors; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; +import org.testcontainers.junit.jupiter.Testcontainers; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -35,9 +37,8 @@ import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import org.testcontainers.junit.jupiter.Testcontainers; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -53,11 +54,6 @@ public class OllamaChatAutoConfigurationIT extends BaseOllamaIT { static String baseUrl; - @BeforeAll - public static void beforeAll() throws IOException, InterruptedException { - baseUrl = buildConnectionWithModel(MODEL_NAME); - } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( // @formatter:off "spring.ai.ollama.baseUrl=" + baseUrl, @@ -69,22 +65,27 @@ public class OllamaChatAutoConfigurationIT extends BaseOllamaIT { private final UserMessage userMessage = new UserMessage("What's the capital of Denmark?"); + @BeforeAll + public static void beforeAll() throws IOException, InterruptedException { + baseUrl = buildConnectionWithModel(MODEL_NAME); + } + @Test public void chatCompletion() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); - ChatResponse response = chatModel.call(new Prompt(userMessage)); + ChatResponse response = chatModel.call(new Prompt(this.userMessage)); assertThat(response.getResult().getOutput().getContent()).contains("Copenhagen"); }); } @Test public void chatCompletionStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); - Flux response = chatModel.stream(new Prompt(userMessage)); + Flux response = chatModel.stream(new Prompt(this.userMessage)); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -102,7 +103,7 @@ public class OllamaChatAutoConfigurationIT extends BaseOllamaIT { @Test public void chatCompletionWithPull() { - contextRunner.withPropertyValues("spring.ai.ollama.init.pull-model-strategy=when_missing") + this.contextRunner.withPropertyValues("spring.ai.ollama.init.pull-model-strategy=when_missing") .withPropertyValues("spring.ai.ollama.chat.options.model=tinyllama") .run(context -> { var model = "tinyllama"; @@ -111,7 +112,7 @@ public class OllamaChatAutoConfigurationIT extends BaseOllamaIT { assertThat(modelManager.isModelAvailable(model)).isTrue(); OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); - ChatResponse response = chatModel.call(new Prompt(userMessage)); + ChatResponse response = chatModel.call(new Prompt(this.userMessage)); assertThat(response.getResult().getOutput().getContent()).contains("Copenhagen"); modelManager.deleteModel(model); }); @@ -119,17 +120,17 @@ public class OllamaChatAutoConfigurationIT extends BaseOllamaIT { @Test void chatActivation() { - contextRunner.withPropertyValues("spring.ai.ollama.chat.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.ollama.chat.enabled=false").run(context -> { assertThat(context.getBeansOfType(OllamaChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OllamaChatModel.class)).isEmpty(); }); - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(OllamaChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OllamaChatModel.class)).isNotEmpty(); }); - contextRunner.withPropertyValues("spring.ai.ollama.chat.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.ollama.chat.enabled=true").run(context -> { assertThat(context.getBeansOfType(OllamaChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OllamaChatModel.class)).isNotEmpty(); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java index 77e7e06a1..14493bcb2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java index 0ea701a67..0a2521aef 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import java.io.IOException; @@ -21,12 +22,13 @@ import java.util.List; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; +import org.testcontainers.junit.jupiter.Testcontainers; + +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.ollama.OllamaEmbeddingModel; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.management.OllamaModelManager; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.springframework.ai.embedding.EmbeddingResponse; -import org.springframework.ai.ollama.OllamaEmbeddingModel; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -46,19 +48,19 @@ public class OllamaEmbeddingAutoConfigurationIT extends BaseOllamaIT { static String baseUrl; - @BeforeAll - public static void beforeAll() throws IOException, InterruptedException { - baseUrl = buildConnectionWithModel(MODEL_NAME); - } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.ollama.embedding.options.model=" + MODEL_NAME, "spring.ai.ollama.base-url=" + baseUrl) .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaAutoConfiguration.class)); + @BeforeAll + public static void beforeAll() throws IOException, InterruptedException { + baseUrl = buildConnectionWithModel(MODEL_NAME); + } + @Test public void singleTextEmbedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaEmbeddingModel embeddingModel = context.getBean(OllamaEmbeddingModel.class); assertThat(embeddingModel).isNotNull(); EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); @@ -70,7 +72,7 @@ public class OllamaEmbeddingAutoConfigurationIT extends BaseOllamaIT { @Test public void embeddingWithPull() { - contextRunner.withPropertyValues("spring.ai.ollama.init.pull-model-strategy=when_missing") + this.contextRunner.withPropertyValues("spring.ai.ollama.init.pull-model-strategy=when_missing") .withPropertyValues("spring.ai.ollama.embedding.options.model=all-minilm") .run(context -> { var model = "all-minilm"; @@ -87,17 +89,17 @@ public class OllamaEmbeddingAutoConfigurationIT extends BaseOllamaIT { @Test void embeddingActivation() { - contextRunner.withPropertyValues("spring.ai.ollama.embedding.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.ollama.embedding.enabled=false").run(context -> { assertThat(context.getBeansOfType(OllamaEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OllamaEmbeddingModel.class)).isEmpty(); }); - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(OllamaEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OllamaEmbeddingModel.class)).isNotEmpty(); }); - contextRunner.withPropertyValues("spring.ai.ollama.embedding.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.ollama.embedding.enabled=true").run(context -> { assertThat(context.getBeansOfType(OllamaEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OllamaEmbeddingModel.class)).isNotEmpty(); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java index 487485a06..bd2a8bfd2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java index fb9db0f36..ebabcc722 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama; public class OllamaImage { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java index b7bc4e408..53f2973cc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.ollama.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.ollama.tool; import java.util.List; import java.util.stream.Collectors; @@ -26,6 +25,9 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Testcontainers; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.ollama.BaseOllamaIT; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -38,9 +40,8 @@ import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import org.testcontainers.junit.jupiter.Testcontainers; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @Testcontainers @DisabledIf("isDisabled") @@ -52,11 +53,6 @@ public class FunctionCallbackInPromptIT extends BaseOllamaIT { static String baseUrl; - @BeforeAll - public static void beforeAll() { - baseUrl = buildConnectionWithModel(MODEL_NAME); - } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( // @formatter:off "spring.ai.ollama.baseUrl=" + baseUrl, @@ -66,9 +62,14 @@ public class FunctionCallbackInPromptIT extends BaseOllamaIT { // @formatter:on .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class)); + @BeforeAll + public static void beforeAll() { + baseUrl = buildConnectionWithModel(MODEL_NAME); + } + @Test void functionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); @@ -95,7 +96,7 @@ public class FunctionCallbackInPromptIT extends BaseOllamaIT { @Disabled("Ollama API does not support streaming function calls yet") @Test void streamingFunctionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); @@ -129,4 +130,4 @@ public class FunctionCallbackInPromptIT extends BaseOllamaIT { }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java index 82fd7eb11..451a970cb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.ollama.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.ollama.tool; import java.util.List; import java.util.stream.Collectors; @@ -26,6 +25,9 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Testcontainers; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.ollama.BaseOllamaIT; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -38,15 +40,13 @@ import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.ai.ollama.OllamaChatModel; -import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.testcontainers.junit.jupiter.Testcontainers; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @Testcontainers @DisabledIf("isDisabled") @@ -58,11 +58,6 @@ public class FunctionCallbackWrapperIT extends BaseOllamaIT { static String baseUrl; - @BeforeAll - public static void beforeAll() { - baseUrl = buildConnectionWithModel(MODEL_NAME); - } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( // @formatter:off "spring.ai.ollama.baseUrl=" + baseUrl, @@ -73,9 +68,14 @@ public class FunctionCallbackWrapperIT extends BaseOllamaIT { .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class)) .withUserConfiguration(Config.class); + @BeforeAll + public static void beforeAll() { + baseUrl = buildConnectionWithModel(MODEL_NAME); + } + @Test void functionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); @@ -94,7 +94,7 @@ public class FunctionCallbackWrapperIT extends BaseOllamaIT { @Disabled("Ollama API does not support streaming function calls yet") @Test void streamFunctionCallTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); @@ -120,7 +120,7 @@ public class FunctionCallbackWrapperIT extends BaseOllamaIT { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); @@ -156,4 +156,4 @@ public class FunctionCallbackWrapperIT extends BaseOllamaIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/MockWeatherService.java index dc780891b..e4a1487c2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.ollama.tool; import java.util.function.Function; @@ -30,16 +31,21 @@ import com.fasterxml.jackson.annotation.JsonPropertyDescription; */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 10; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -67,28 +73,25 @@ public class MockWeatherService implements Function { + this.contextRunner.run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); String response = chatModel.call("Hello"); assertThat(response).isNotEmpty(); @@ -65,7 +66,7 @@ public class OpenAiAutoConfigurationIT { @Test void transcribe() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OpenAiAudioTranscriptionModel transcriptionModel = context.getBean(OpenAiAudioTranscriptionModel.class); Resource audioFile = new ClassPathResource("/speech/jfk.flac"); String response = transcriptionModel.call(audioFile); @@ -76,7 +77,7 @@ public class OpenAiAutoConfigurationIT { @Test void speech() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OpenAiAudioSpeechModel speechModel = context.getBean(OpenAiAudioSpeechModel.class); byte[] response = speechModel.call("H"); assertThat(response).isNotNull(); @@ -102,7 +103,7 @@ public class OpenAiAutoConfigurationIT { @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); String response = responseFlux.collectList().block().stream().map(chatResponse -> { @@ -116,7 +117,7 @@ public class OpenAiAutoConfigurationIT { @Test void streamingWithTokenUsage() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.stream-usage=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.stream-usage=true").run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); @@ -138,7 +139,7 @@ public class OpenAiAutoConfigurationIT { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { OpenAiEmbeddingModel embeddingModel = context.getBean(OpenAiEmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingModel @@ -155,7 +156,7 @@ public class OpenAiAutoConfigurationIT { @Test void generateImage() { - contextRunner.withPropertyValues("spring.ai.openai.image.options.size=1024x1024").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.openai.image.options.size=1024x1024").run(context -> { OpenAiImageModel imageModel = context.getBean(OpenAiImageModel.class); ImageResponse imageResponse = imageModel.call(new ImagePrompt("forest")); assertThat(imageResponse.getResults()).hasSize(1); @@ -167,7 +168,7 @@ public class OpenAiAutoConfigurationIT { @Test void generateImageWithModel() { // The 256x256 size is supported by dall-e-2, but not by dall-e-3. - contextRunner + this.contextRunner .withPropertyValues("spring.ai.openai.image.options.model=dall-e-2", "spring.ai.openai.image.options.size=256x256") .run(context -> { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java index 5c38bfee6..3ba07792b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,26 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.openai; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.openai; import org.junit.jupiter.api.Test; import org.skyscreamer.jsonassert.JSONAssert; import org.skyscreamer.jsonassert.JSONCompareMode; + import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.OpenAiAudioSpeechModel; import org.springframework.ai.openai.OpenAiAudioTranscriptionModel; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiEmbeddingModel; import org.springframework.ai.openai.OpenAiImageModel; -import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder; import org.springframework.ai.openai.api.OpenAiApi.FunctionTool.Type; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + /** * Unit Tests for {@link OpenAiConnectionProperties}, {@link OpenAiChatProperties} and * {@link OpenAiEmbeddingProperties}. diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java index 03a3eba1e..03d6af56c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.jupiter.api.Test; + import org.springframework.ai.openai.OpenAiAudioSpeechModel; import org.springframework.ai.openai.OpenAiAudioTranscriptionModel; import org.springframework.ai.openai.OpenAiChatModel; @@ -28,6 +28,8 @@ import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + /** * Unit Tests for {@link OpenAiChatProperties} #options#responseFormat support. * diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java index cdf76ec24..538c8456a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.openai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.openai.tool; import java.util.function.Function; import java.util.stream.Collectors; @@ -24,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.openai.OpenAiChatModel; @@ -31,6 +31,8 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatModel; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") public class FunctionCallbackInPrompt2IT { @@ -42,7 +44,7 @@ public class FunctionCallbackInPrompt2IT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) .run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -60,7 +62,7 @@ public class FunctionCallbackInPrompt2IT { .call().content(); // @formatter:on - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); }); @@ -68,7 +70,7 @@ public class FunctionCallbackInPrompt2IT { @Test void functionCallTest2() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) .run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -85,7 +87,7 @@ public class FunctionCallbackInPrompt2IT { }) .call().content(); // @formatter:on - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).contains("18"); }); @@ -94,7 +96,7 @@ public class FunctionCallbackInPrompt2IT { @Test void streamingFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) .run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -107,10 +109,10 @@ public class FunctionCallbackInPrompt2IT { .collectList().block().stream().collect(Collectors.joining()); // @formatter:on - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java index e5c0c4fca..4de98c177 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.openai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.openai.tool; import java.util.List; import java.util.stream.Collectors; @@ -24,6 +23,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -37,7 +38,7 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatModel; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") public class FunctionCallbackInPromptIT { @@ -50,7 +51,7 @@ public class FunctionCallbackInPromptIT { @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName(), "spring.ai.openai.chat.options.temperature=0.1") .run(context -> { @@ -70,7 +71,7 @@ public class FunctionCallbackInPromptIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -79,7 +80,7 @@ public class FunctionCallbackInPromptIT { @Test void streamingFunctionCallTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName(), "spring.ai.openai.chat.options.temperature=0.5") .run(context -> { @@ -107,10 +108,10 @@ public class FunctionCallbackInPromptIT { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index beb829278..c4b843821 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -15,8 +15,6 @@ */ package org.springframework.ai.autoconfigure.openai.tool; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.List; import java.util.Map; import java.util.function.BiFunction; @@ -27,6 +25,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; @@ -46,7 +46,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") class FunctionCallbackWithPlainFunctionBeanIT { @@ -269,4 +269,4 @@ class FunctionCallbackWithPlainFunctionBeanIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapper2IT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapper2IT.java index aaf84d98a..0056fc20b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapper2IT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapper2IT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.openai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.openai.tool; import java.util.stream.Collectors; @@ -23,6 +22,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.model.function.FunctionCallback; @@ -34,6 +34,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import static org.assertj.core.api.Assertions.assertThat; + @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") public class FunctionCallbackWrapper2IT { @@ -47,7 +49,7 @@ public class FunctionCallbackWrapper2IT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -62,7 +64,7 @@ public class FunctionCallbackWrapper2IT { .call().content(); // @formatter:on - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); }); @@ -70,7 +72,7 @@ public class FunctionCallbackWrapper2IT { @Test void streamFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -82,7 +84,7 @@ public class FunctionCallbackWrapper2IT { .collectList().block().stream().collect(Collectors.joining()); // @formatter:on - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); }); @@ -103,4 +105,4 @@ public class FunctionCallbackWrapper2IT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java index 5020b4b56..01a1bd1a9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.openai.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.openai.tool; import java.util.List; import java.util.stream.Collectors; @@ -24,6 +23,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -40,7 +41,7 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import reactor.core.publisher.Flux; +import static org.assertj.core.api.Assertions.assertThat; @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") public class FunctionCallbackWrapperIT { @@ -55,7 +56,7 @@ public class FunctionCallbackWrapperIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -64,7 +65,7 @@ public class FunctionCallbackWrapperIT { ChatResponse response = chatModel.call( new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withFunction("WeatherInfo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -73,7 +74,7 @@ public class FunctionCallbackWrapperIT { @Test void streamFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -91,7 +92,7 @@ public class FunctionCallbackWrapperIT { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -115,4 +116,4 @@ public class FunctionCallbackWrapperIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/MockWeatherService.java index 60fd35af1..f0026ca9f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai.tool; import java.util.function.Function; @@ -30,16 +31,21 @@ import com.fasterxml.jackson.annotation.JsonPropertyDescription; */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 10; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -67,28 +73,25 @@ public class MockWeatherService implements Function jdbcTemplate) + .withBean(JdbcTemplate.class, () -> this.jdbcTemplate) .withConfiguration(AutoConfigurations.of(PostgresMlAutoConfiguration.class)); contextRunner.run(context -> { PostgresMlEmbeddingModel embeddingModel = context.getBean(PostgresMlEmbeddingModel.class); @@ -85,7 +86,7 @@ public class PostgresMlAutoConfigurationIT { @Test void embeddingActivation() { - new ApplicationContextRunner().withBean(JdbcTemplate.class, () -> jdbcTemplate) + new ApplicationContextRunner().withBean(JdbcTemplate.class, () -> this.jdbcTemplate) .withConfiguration(AutoConfigurations.of(PostgresMlAutoConfiguration.class)) .withPropertyValues("spring.ai.postgresml.embedding.enabled=false") .run(context -> { @@ -93,7 +94,7 @@ public class PostgresMlAutoConfigurationIT { assertThat(context.getBeansOfType(PostgresMlEmbeddingModel.class)).isEmpty(); }); - new ApplicationContextRunner().withBean(JdbcTemplate.class, () -> jdbcTemplate) + new ApplicationContextRunner().withBean(JdbcTemplate.class, () -> this.jdbcTemplate) .withConfiguration(AutoConfigurations.of(PostgresMlAutoConfiguration.class)) .withPropertyValues("spring.ai.postgresml.embedding.enabled=true") .run(context -> { @@ -101,7 +102,7 @@ public class PostgresMlAutoConfigurationIT { assertThat(context.getBeansOfType(PostgresMlEmbeddingModel.class)).isNotEmpty(); }); - new ApplicationContextRunner().withBean(JdbcTemplate.class, () -> jdbcTemplate) + new ApplicationContextRunner().withBean(JdbcTemplate.class, () -> this.jdbcTemplate) .withConfiguration(AutoConfigurations.of(PostgresMlAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(PostgresMlEmbeddingProperties.class)).isNotEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingPropertiesTests.java index 0576e7b5b..be6a53585 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.postgresml; import java.util.Map; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java index a854b8ba5..002a5f578 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.image.ImagePrompt; @@ -33,11 +40,6 @@ import org.springframework.ai.qianfan.QianFanImageModel; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -58,7 +60,7 @@ public class QianFanAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { QianFanChatModel client = context.getBean(QianFanChatModel.class); String response = client.call("Hello"); assertThat(response).isNotEmpty(); @@ -68,7 +70,7 @@ public class QianFanAutoConfigurationIT { @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { QianFanChatModel client = context.getBean(QianFanChatModel.class); Flux responseFlux = client.stream(new Prompt(new UserMessage("Hello"))); String response = Objects.requireNonNull(responseFlux.collectList().block()) @@ -82,7 +84,7 @@ public class QianFanAutoConfigurationIT { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { QianFanEmbeddingModel embeddingClient = context.getBean(QianFanEmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingClient @@ -99,7 +101,7 @@ public class QianFanAutoConfigurationIT { @Test void generateImage() { - contextRunner.withPropertyValues("spring.ai.qianfan.image.options.size=1024x1024").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.qianfan.image.options.size=1024x1024").run(context -> { QianFanImageModel imageModel = context.getBean(QianFanImageModel.class); ImageResponse imageResponse = imageModel.call(new ImagePrompt("forest")); assertThat(imageResponse.getResults()).hasSize(1); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanPropertiesTests.java index c5acafd78..1a0ee813f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.qianfan; import org.junit.jupiter.api.Test; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.qianfan.QianFanChatModel; import org.springframework.ai.qianfan.QianFanEmbeddingModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java index f15e605fc..64759b680 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.retry; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java index c663dfb3a..c3baab91c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.retry; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiAutoConfigurationIT.java index 0423170af..34eb67663 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.stabilityai; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.image.Image; -import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageGeneration; +import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; import org.springframework.ai.stabilityai.StyleEnum; @@ -38,7 +40,7 @@ public class StabilityAiAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { ImageModel imageModel = context.getBean(ImageModel.class); StabilityAiImageOptions imageOptions = StabilityAiImageOptions.builder() .withStylePreset(StyleEnum.PHOTOGRAPHIC) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImagePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImagePropertiesTests.java index 6fbb85094..c267dd765 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImagePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImagePropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.stabilityai; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfigurationIT.java index 6a7cacabb..9b3a597be 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.transformers; import java.io.File; @@ -33,15 +34,15 @@ import static org.assertj.core.api.Assertions.assertThat; */ public class TransformersEmbeddingModelAutoConfigurationIT { - @TempDir - File tempDir; - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(TransformersEmbeddingModelAutoConfiguration.class)); + @TempDir + File tempDir; + @Test public void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { var properties = context.getBean(TransformersEmbeddingModelProperties.class); assertThat(properties.getCache().isEnabled()).isTrue(); assertThat(properties.getCache().getDirectory()).isEqualTo( @@ -54,14 +55,15 @@ public class TransformersEmbeddingModelAutoConfigurationIT { assertThat(embeddings.size()).isEqualTo(2); // batch size assertThat(embeddings.get(0).length).isEqualTo(embeddingModel.dimensions()); // dimensions - // size + // size }); } @Test public void remoteOnnxModel() { // https://huggingface.co/intfloat/e5-small-v2 - contextRunner.withPropertyValues("spring.ai.embedding.transformer.cache.directory=" + tempDir.getAbsolutePath(), + this.contextRunner.withPropertyValues( + "spring.ai.embedding.transformer.cache.directory=" + this.tempDir.getAbsolutePath(), "spring.ai.embedding.transformer.onnx.modelUri=https://huggingface.co/intfloat/e5-small-v2/resolve/main/model.onnx", "spring.ai.embedding.transformer.tokenizer.uri=https://huggingface.co/intfloat/e5-small-v2/raw/main/tokenizer.json") .run(context -> { @@ -72,8 +74,8 @@ public class TransformersEmbeddingModelAutoConfigurationIT { .isEqualTo("https://huggingface.co/intfloat/e5-small-v2/raw/main/tokenizer.json"); assertThat(properties.getCache().isEnabled()).isTrue(); - assertThat(properties.getCache().getDirectory()).isEqualTo(tempDir.getAbsolutePath()); - assertThat(tempDir.listFiles()).hasSize(2); + assertThat(properties.getCache().getDirectory()).isEqualTo(this.tempDir.getAbsolutePath()); + assertThat(this.tempDir.listFiles()).hasSize(2); EmbeddingModel embeddingModel = context.getBean(EmbeddingModel.class); assertThat(embeddingModel).isInstanceOf(TransformersEmbeddingModel.class); @@ -84,23 +86,23 @@ public class TransformersEmbeddingModelAutoConfigurationIT { assertThat(embeddings.size()).isEqualTo(2); // batch size assertThat(embeddings.get(0).length).isEqualTo(embeddingModel.dimensions()); // dimensions - // size + // size }); } @Test void embeddingActivation() { - contextRunner.withPropertyValues("spring.ai.embedding.transformer.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.embedding.transformer.enabled=false").run(context -> { assertThat(context.getBeansOfType(TransformersEmbeddingModelProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(TransformersEmbeddingModel.class)).isEmpty(); }); - contextRunner.withPropertyValues("spring.ai.embedding.transformer.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.embedding.transformer.enabled=true").run(context -> { assertThat(context.getBeansOfType(TransformersEmbeddingModelProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(TransformersEmbeddingModel.class)).isNotEmpty(); }); - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(TransformersEmbeddingModelProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(TransformersEmbeddingModel.class)).isNotEmpty(); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfigurationIT.java index 46ad43c9c..1590d750d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.azure; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.azure; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -26,10 +23,12 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import io.micrometer.observation.tck.TestObservationRegistry; import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.VectorStoreProvider; @@ -44,7 +43,9 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Christian Tzolov @@ -55,6 +56,13 @@ import io.micrometer.observation.tck.TestObservationRegistry; @EnabledIfEnvironmentVariable(named = "AZURE_AI_SEARCH_ENDPOINT", matches = ".+") public class AzureVectorStoreAutoConfigurationIT { + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(AzureVectorStoreAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .withPropertyValues("spring.ai.vectorstore.azure.apiKey=" + System.getenv("AZURE_AI_SEARCH_API_KEY"), + "spring.ai.vectorstore.azure.url=" + System.getenv("AZURE_AI_SEARCH_ENDPOINT")) + .withPropertyValues("spring.ai.vectorstore.azure.initialize-schema=true"); + List documents = List.of( new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), @@ -70,13 +78,6 @@ public class AzureVectorStoreAutoConfigurationIT { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(AzureVectorStoreAutoConfiguration.class)) - .withUserConfiguration(Config.class) - .withPropertyValues("spring.ai.vectorstore.azure.apiKey=" + System.getenv("AZURE_AI_SEARCH_API_KEY"), - "spring.ai.vectorstore.azure.url=" + System.getenv("AZURE_AI_SEARCH_ENDPOINT")) - .withPropertyValues("spring.ai.vectorstore.azure.initialize-schema=true"); - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -87,7 +88,7 @@ public class AzureVectorStoreAutoConfigurationIT { @Test public void addAndSearchTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vectorstore.azure.initializeSchema=true", "spring.ai.vectorstore.azure.indexName=my_test_index", "spring.ai.vectorstore.azure.defaultTopK=6", "spring.ai.vectorstore.azure.defaultSimilarityThreshold=0.75") @@ -106,7 +107,7 @@ public class AzureVectorStoreAutoConfigurationIT { assertThat(vectorStore).isInstanceOf(AzureVectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await().until(() -> { return vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); @@ -120,7 +121,7 @@ public class AzureVectorStoreAutoConfigurationIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); @@ -131,7 +132,7 @@ public class AzureVectorStoreAutoConfigurationIT { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); Awaitility.await().until(() -> { return vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java index 36bf85110..47b56c74b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.cassandra; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.cassandra; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.CassandraContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -35,12 +39,9 @@ import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfigurati import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.testcontainers.containers.CassandraContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.DockerImageName; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Mick Semb Wever @@ -56,11 +57,6 @@ class CassandraVectorStoreAutoConfigurationIT { @Container static CassandraContainer cassandraContainer = new CassandraContainer(DEFAULT_IMAGE_NAME.withTag("5.0")); - List documents = List.of( - new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), - new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( - ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration( AutoConfigurations.of(CassandraVectorStoreAutoConfiguration.class, CassandraAutoConfiguration.class)) @@ -69,9 +65,14 @@ class CassandraVectorStoreAutoConfigurationIT { .withPropertyValues("spring.ai.vectorstore.cassandra.keyspace=test_autoconfigure") .withPropertyValues("spring.ai.vectorstore.cassandra.contentColumnName=doc_chunk"); + List documents = List.of( + new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), + new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( + ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + @Test void addAndSearch() { - contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost()) + this.contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost()) .withPropertyValues("spring.cassandra.port=" + getContactPointPort()) .withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter()) .withPropertyValues("spring.ai.vectorstore.cassandra.fixedThreadPoolExecutorSize=8") @@ -79,7 +80,7 @@ class CassandraVectorStoreAutoConfigurationIT { .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.CASSANDRA, VectorStoreObservationContext.Operation.ADD); @@ -89,7 +90,7 @@ class CassandraVectorStoreAutoConfigurationIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); @@ -98,7 +99,7 @@ class CassandraVectorStoreAutoConfigurationIT { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).isEmpty(); @@ -109,6 +110,14 @@ class CassandraVectorStoreAutoConfigurationIT { }); } + private String getContactPointHost() { + return cassandraContainer.getContactPoint().getHostString(); + } + + private String getContactPointPort() { + return String.valueOf(cassandraContainer.getContactPoint().getPort()); + } + @Configuration(proxyBeanMethods = false) static class Config { @@ -124,12 +133,4 @@ class CassandraVectorStoreAutoConfigurationIT { } - private String getContactPointHost() { - return cassandraContainer.getContactPoint().getHostString(); - } - - private String getContactPointPort() { - return String.valueOf(cassandraContainer.getContactPoint().getPort()); - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java index cfa1fc375..3c8201eda 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStorePropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.cassandra; import org.junit.jupiter.api.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java index 7b061afe5..a22b26c30 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java @@ -18,16 +18,14 @@ package org.springframework.ai.autoconfigure.vectorstore.chroma; import java.util.List; import java.util.Map; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.testcontainers.chromadb.ChromaDBContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; - import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.VectorStoreProvider; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfigurationIT.java index 7376cf99a..1a8e09201 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,10 +16,16 @@ package org.springframework.ai.autoconfigure.vectorstore.cosmosdb; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -30,10 +36,7 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.UUID; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -61,8 +64,8 @@ public class CosmosDBVectorStoreAutoConfigurationIT { @BeforeEach public void setup() { - contextRunner.run(context -> { - vectorStore = context.getBean(VectorStore.class); + this.contextRunner.run(context -> { + this.vectorStore = context.getBean(VectorStore.class); }); } @@ -74,20 +77,20 @@ public class CosmosDBVectorStoreAutoConfigurationIT { Document document2 = new Document(UUID.randomUUID().toString(), "Sample content2", Map.of("key2", "value2")); // Add the document to the vector store - vectorStore.add(List.of(document1, document2)); + this.vectorStore.add(List.of(document1, document2)); // Perform a similarity search - List results = vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); + List results = this.vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); // Verify the search results assertThat(results).isNotEmpty(); assertThat(results.get(0).getId()).isEqualTo(document1.getId()); // Remove the documents from the vector store - vectorStore.delete(List.of(document1.getId(), document2.getId())); + this.vectorStore.delete(List.of(document1.getId(), document2.getId())); // Perform a similarity search again - List results2 = vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); + List results2 = this.vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); // Verify the search results assertThat(results2).isEmpty(); @@ -126,16 +129,16 @@ public class CosmosDBVectorStoreAutoConfigurationIT { Document document3 = new Document("3", "A document about the US", metadata3); Document document4 = new Document("4", "A document about the US", metadata4); - vectorStore.add(List.of(document1, document2, document3, document4)); + this.vectorStore.add(List.of(document1, document2, document3, document4)); FilterExpressionBuilder b = new FilterExpressionBuilder(); - List results = vectorStore.similaritySearch(SearchRequest.query("The World") + List results = this.vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(10) .withFilterExpression((b.in("country", "UK", "NL")).build())); assertThat(results).hasSize(2); assertThat(results).extracting(Document::getId).containsExactlyInAnyOrder("1", "2"); - List results2 = vectorStore.similaritySearch(SearchRequest.query("The World") + List results2 = this.vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(10) .withFilterExpression( b.and(b.or(b.gte("year", 2021), b.eq("country", "NL")), b.ne("city", "Amsterdam")).build())); @@ -143,17 +146,17 @@ public class CosmosDBVectorStoreAutoConfigurationIT { assertThat(results2).hasSize(1); assertThat(results2).extracting(Document::getId).containsExactlyInAnyOrder("1"); - List results3 = vectorStore.similaritySearch(SearchRequest.query("The World") + List results3 = this.vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(10) .withFilterExpression(b.and(b.eq("country", "US"), b.eq("year", 2020)).build())); assertThat(results3).hasSize(1); assertThat(results3).extracting(Document::getId).containsExactlyInAnyOrder("4"); - vectorStore.delete(List.of(document1.getId(), document2.getId(), document3.getId(), document4.getId())); + this.vectorStore.delete(List.of(document1.getId(), document2.getId(), document3.getId(), document4.getId())); // Perform a similarity search again - List results4 = vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(1)); + List results4 = this.vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(1)); // Verify the search results assertThat(results4).isEmpty(); @@ -174,4 +177,4 @@ public class CosmosDBVectorStoreAutoConfigurationIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java index 60eca7d06..08b261c46 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,20 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.elasticsearch; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.elasticsearch; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.awaitility.Awaitility; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.testcontainers.elasticsearch.ElasticsearchContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.document.Document; @@ -42,11 +44,10 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.elasticsearch.ElasticsearchContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; @Testcontainers @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @@ -57,11 +58,6 @@ class ElasticsearchVectorStoreAutoConfigurationIT { "docker.elastic.co/elasticsearch/elasticsearch:8.12.2") .withEnv("xpack.security.enabled", "false"); - private List documents = List.of( - new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), - new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), - new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(ElasticsearchRestClientAutoConfiguration.class, ElasticsearchVectorStoreAutoConfiguration.class, RestClientAutoConfiguration.class, @@ -71,6 +67,11 @@ class ElasticsearchVectorStoreAutoConfigurationIT { "spring.ai.vectorstore.elasticsearch.initializeSchema=true", "spring.ai.openai.api-key=" + System.getenv("OPENAI_API_KEY")); + private List documents = List.of( + new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), + new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), + new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + // No parametrized test based on similarity function, // by default the bean will be created using cosine. @Test @@ -80,7 +81,7 @@ class ElasticsearchVectorStoreAutoConfigurationIT { ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.ELASTICSEARCH, VectorStoreObservationContext.Operation.ADD); @@ -98,7 +99,7 @@ class ElasticsearchVectorStoreAutoConfigurationIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); @@ -109,7 +110,7 @@ class ElasticsearchVectorStoreAutoConfigurationIT { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.ELASTICSEARCH, VectorStoreObservationContext.Operation.DELETE); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java index b280052ed..54dbc4420 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java @@ -16,19 +16,23 @@ package org.springframework.ai.autoconfigure.vectorstore.gemfire; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; - import java.util.HashMap; import java.util.List; import java.util.Map; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.dockerjava.api.model.ExposedPort; +import com.github.dockerjava.api.model.PortBinding; +import com.github.dockerjava.api.model.Ports; +import com.vmware.gemfire.testcontainers.GemFireCluster; +import io.micrometer.observation.tck.TestObservationRegistry; import org.awaitility.Awaitility; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -43,14 +47,9 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.github.dockerjava.api.model.ExposedPort; -import com.github.dockerjava.api.model.PortBinding; -import com.github.dockerjava.api.model.Ports; -import com.vmware.gemfire.testcontainers.GemFireCluster; - -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Geet Rawat @@ -59,8 +58,6 @@ import io.micrometer.observation.tck.TestObservationRegistry; */ class GemFireVectorStoreAutoConfigurationIT { - private static GemFireCluster gemFireCluster; - private static final String INDEX_NAME = "spring-ai-index"; private static final int BEAM_WIDTH = 50; @@ -79,15 +76,7 @@ class GemFireVectorStoreAutoConfigurationIT { private static final int SERVER_COUNT = 1; - @AfterAll - public static void stopGemFireCluster() { - gemFireCluster.close(); - } - - List documents = List.of( - new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), - new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( - ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + private static GemFireCluster gemFireCluster; private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(GemFireVectorStoreAutoConfiguration.class)) @@ -102,6 +91,16 @@ class GemFireVectorStoreAutoConfigurationIT { .withPropertyValues("spring.ai.vectorstore.gemfire.port=" + HTTP_SERVICE_PORT) .withPropertyValues("spring.ai.vectorstore.gemfire.initialize-schema=true"); + List documents = List.of( + new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), + new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( + ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + + @AfterAll + public static void stopGemFireCluster() { + gemFireCluster.close(); + } + @BeforeAll public static void startGemFireCluster() { Ports.Binding hostPort = Ports.Binding.bindPort(HTTP_SERVICE_PORT); @@ -144,11 +143,11 @@ class GemFireVectorStoreAutoConfigurationIT { @Test public void addAndSearchTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.GEMFIRE, VectorStoreObservationContext.Operation.ADD); @@ -166,14 +165,14 @@ class GemFireVectorStoreAutoConfigurationIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKeys("spring", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.GEMFIRE, VectorStoreObservationContext.Operation.DELETE); @@ -190,18 +189,24 @@ class GemFireVectorStoreAutoConfigurationIT { JsonNode rootNode = new ObjectMapper().readTree(json); Map indexDetails = new HashMap<>(); if (rootNode.isObject()) { - if (rootNode.has("name")) + if (rootNode.has("name")) { indexDetails.put("name", rootNode.get("name").asText()); - if (rootNode.has("beam-width")) + } + if (rootNode.has("beam-width")) { indexDetails.put("beam-width", rootNode.get("beam-width").asInt()); - if (rootNode.has("max-connections")) + } + if (rootNode.has("max-connections")) { indexDetails.put("max-connections", rootNode.get("max-connections").asInt()); - if (rootNode.has("vector-similarity-function")) + } + if (rootNode.has("vector-similarity-function")) { indexDetails.put("vector-similarity-function", rootNode.get("vector-similarity-function").asText()); - if (rootNode.has("buckets")) + } + if (rootNode.has("buckets")) { indexDetails.put("buckets", rootNode.get("buckets").asInt()); - if (rootNode.has("number-of-embeddings")) + } + if (rootNode.has("number-of-embeddings")) { indexDetails.put("number-of-embeddings", rootNode.get("number-of-embeddings").asInt()); + } } return indexDetails; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStorePropertiesTests.java index c8ef301c0..5f69d6ebb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStorePropertiesTests.java @@ -16,12 +16,12 @@ package org.springframework.ai.autoconfigure.vectorstore.gemfire; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.jupiter.api.Test; import org.springframework.ai.vectorstore.GemFireVectorStore; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Geet Rawat * @author Soby Chacko diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfigurationIT.java index 5801f88f1..c8d620d08 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStoreAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.hanadb; +import java.util.List; + import org.junit.Test; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.document.Document; @@ -28,8 +32,6 @@ import org.springframework.boot.autoconfigure.data.jdbc.JdbcRepositoriesAutoConf import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import java.util.List; - @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @EnabledIfEnvironmentVariable(named = "HANA_DATASOURCE_URL", matches = ".+") @EnabledIfEnvironmentVariable(named = "HANA_DATASOURCE_USERNAME", matches = ".+") @@ -37,22 +39,6 @@ import java.util.List; @Disabled public class HanaCloudVectorStoreAutoConfigurationIT { - @Test - public void addAndSearch() { - contextRunner.run(context -> { - VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); - - List results = vectorStore.similaritySearch("What is Great Depression?"); - Assertions.assertEquals(1, results.size()); - - // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); - List results2 = vectorStore.similaritySearch("Great Depression"); - Assertions.assertEquals(0, results2.size()); - }); - } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(HanaCloudVectorStoreAutoConfiguration.class, OpenAiAutoConfiguration.class, RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class, @@ -70,4 +56,20 @@ public class HanaCloudVectorStoreAutoConfigurationIT { new Document( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression")); + @Test + public void addAndSearch() { + this.contextRunner.run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); + vectorStore.add(this.documents); + + List results = vectorStore.similaritySearch("What is Great Depression?"); + Assertions.assertEquals(1, results.size()); + + // Remove all documents from the store + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); + List results2 = vectorStore.similaritySearch("Great Depression"); + Assertions.assertEquals(0, results2.size()); + }); + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStorePropertiesTest.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStorePropertiesTest.java index 1756e8cd7..eab9f8f3e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStorePropertiesTest.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/hanadb/HanaCloudVectorStorePropertiesTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.hanadb; import org.junit.Test; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfigurationIT.java index 6697dffbc..15723b9b0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.milvus; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.milvus; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.milvus.MilvusContainer; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -34,11 +37,9 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.milvus.MilvusContainer; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Christian Tzolov @@ -52,18 +53,18 @@ public class MilvusVectorStoreAutoConfigurationIT { @Container private static MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.3.8"); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(MilvusVectorStoreAutoConfiguration.class)) + .withUserConfiguration(Config.class); + List documents = List.of( new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(MilvusVectorStoreAutoConfiguration.class)) - .withUserConfiguration(Config.class); - @Test public void addAndSearch() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vectorstore.milvus.metricType=COSINE", "spring.ai.vectorstore.milvus.indexType=IVF_FLAT", "spring.ai.vectorstore.milvus.embeddingDimension=384", @@ -75,7 +76,7 @@ public class MilvusVectorStoreAutoConfigurationIT { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.MILVUS, VectorStoreObservationContext.Operation.ADD); @@ -85,7 +86,7 @@ public class MilvusVectorStoreAutoConfigurationIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); @@ -96,7 +97,7 @@ public class MilvusVectorStoreAutoConfigurationIT { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(0); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfigurationIT.java index 9e10fcc06..c47f3ea25 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.mongo; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.mongo; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; + import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.document.Document; @@ -40,11 +43,9 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.data.mongodb.core.MongoTemplate; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import io.micrometer.observation.tck.TestObservationRegistry; -import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Eddú Meléndez @@ -59,6 +60,19 @@ class MongoDBAtlasVectorStoreAutoConfigurationIT { @Container static MongoDBAtlasLocalContainer mongo = new MongoDBAtlasLocalContainer("mongodb/mongodb-atlas-local:7.0.9"); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class) + .withConfiguration(AutoConfigurations.of(MongoAutoConfiguration.class, MongoDataAutoConfiguration.class, + MongoDBAtlasVectorStoreAutoConfiguration.class, RestClientAutoConfiguration.class, + SpringAiRetryAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withPropertyValues("spring.data.mongodb.database=springaisample", + "spring.ai.vectorstore.mongodb.initialize-schema=true", + "spring.ai.vectorstore.mongodb.collection-name=test_collection", + // "spring.ai.vectorstore.mongodb.path-name=testembedding", + "spring.ai.vectorstore.mongodb.index-name=text_index", + "spring.ai.openai.api-key=" + System.getenv("OPENAI_API_KEY"), + String.format("spring.data.mongodb.uri=" + mongo.getConnectionString())); + List documents = List.of( new Document("Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!!", Collections.singletonMap("meta1", "meta1")), @@ -73,27 +87,14 @@ class MongoDBAtlasVectorStoreAutoConfigurationIT { "Testcontainers Testcontainers Testcontainers Testcontainers Testcontainers Testcontainers Testcontainers", Collections.singletonMap("foo", "baz"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class) - .withConfiguration(AutoConfigurations.of(MongoAutoConfiguration.class, MongoDataAutoConfiguration.class, - MongoDBAtlasVectorStoreAutoConfiguration.class, RestClientAutoConfiguration.class, - SpringAiRetryAutoConfiguration.class, OpenAiAutoConfiguration.class)) - .withPropertyValues("spring.data.mongodb.database=springaisample", - "spring.ai.vectorstore.mongodb.initialize-schema=true", - "spring.ai.vectorstore.mongodb.collection-name=test_collection", - // "spring.ai.vectorstore.mongodb.path-name=testembedding", - "spring.ai.vectorstore.mongodb.index-name=text_index", - "spring.ai.openai.api-key=" + System.getenv("OPENAI_API_KEY"), - String.format("spring.data.mongodb.uri=" + mongo.getConnectionString())); - @Test public void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.MONGODB, VectorStoreObservationContext.Operation.ADD); observationRegistry.clear(); @@ -104,7 +105,7 @@ class MongoDBAtlasVectorStoreAutoConfigurationIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsEntry("meta2", "meta2"); @@ -114,7 +115,7 @@ class MongoDBAtlasVectorStoreAutoConfigurationIT { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).collect(Collectors.toList())); + vectorStore.delete(this.documents.stream().map(Document::getId).collect(Collectors.toList())); assertObservationRegistry(observationRegistry, VectorStoreProvider.MONGODB, VectorStoreObservationContext.Operation.DELETE); @@ -129,29 +130,32 @@ class MongoDBAtlasVectorStoreAutoConfigurationIT { @Test public void addAndSearchWithFilters() { - contextRunner.withPropertyValues("spring.ai.vectorstore.mongodb.metadata-fields-to-filter=foo").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vectorstore.mongodb.metadata-fields-to-filter=foo") + .run(context -> { - VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + VectorStore vectorStore = context.getBean(VectorStore.class); + vectorStore.add(this.documents); - Thread.sleep(5000); // Await a second for the document to be indexed + Thread.sleep(5000); // Await a second for the document to be indexed - List results = vectorStore.similaritySearch(SearchRequest.query("Testcontainers").withTopK(2)); - assertThat(results).hasSize(2); - results.forEach(doc -> assertThat(doc.getContent().contains("Testcontainers")).isTrue()); + List results = vectorStore + .similaritySearch(SearchRequest.query("Testcontainers").withTopK(2)); + assertThat(results).hasSize(2); + results.forEach(doc -> assertThat(doc.getContent().contains("Testcontainers")).isTrue()); - FilterExpressionBuilder b = new FilterExpressionBuilder(); - results = vectorStore.similaritySearch( - SearchRequest.query("Testcontainers").withTopK(2).withFilterExpression(b.eq("foo", "bar").build())); + FilterExpressionBuilder b = new FilterExpressionBuilder(); + results = vectorStore.similaritySearch(SearchRequest.query("Testcontainers") + .withTopK(2) + .withFilterExpression(b.eq("foo", "bar").build())); - assertThat(results).hasSize(1); - Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(3).getId()); - assertThat(resultDoc.getContent().contains("Testcontainers")).isTrue(); - assertThat(resultDoc.getMetadata()).containsEntry("foo", "bar"); + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(3).getId()); + assertThat(resultDoc.getContent().contains("Testcontainers")).isTrue(); + assertThat(resultDoc.getMetadata()).containsEntry("foo", "bar"); - context.getBean(MongoTemplate.class).dropCollection("test_collection"); - }); + context.getBean(MongoTemplate.class).dropCollection("test_collection"); + }); } @Configuration(proxyBeanMethods = false) @@ -164,4 +168,4 @@ class MongoDBAtlasVectorStoreAutoConfigurationIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfigurationIT.java index 3f56ca284..09b12c83a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.neo4j; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.neo4j; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.Neo4jContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -35,12 +39,9 @@ import org.springframework.boot.autoconfigure.neo4j.Neo4jAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.testcontainers.containers.Neo4jContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.DockerImageName; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Jingzhou Ou @@ -57,11 +58,6 @@ public class Neo4jVectorStoreAutoConfigurationIT { static Neo4jContainer neo4jContainer = new Neo4jContainer<>(DockerImageName.parse("neo4j:5.18")) .withRandomPassword(); - List documents = List.of( - new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), - new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( - ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(Neo4jAutoConfiguration.class, Neo4jVectorStoreAutoConfiguration.class)) .withUserConfiguration(Config.class) @@ -69,9 +65,14 @@ public class Neo4jVectorStoreAutoConfigurationIT { "spring.ai.vectorstore.neo4j.initialize-schema=true", "spring.neo4j.authentication.username=" + "neo4j", "spring.neo4j.authentication.password=" + neo4jContainer.getAdminPassword()); + List documents = List.of( + new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), + new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( + ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + @Test void addAndSearch() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vectorstore.neo4j.label=my_test_label", "spring.ai.vectorstore.neo4j.embeddingDimension=384", "spring.ai.vectorstore.neo4j.indexName=customIndexName") @@ -84,7 +85,7 @@ public class Neo4jVectorStoreAutoConfigurationIT { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.NEO4J, VectorStoreObservationContext.Operation.ADD); @@ -94,7 +95,7 @@ public class Neo4jVectorStoreAutoConfigurationIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); @@ -103,7 +104,7 @@ public class Neo4jVectorStoreAutoConfigurationIT { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.NEO4J, VectorStoreObservationContext.Operation.DELETE); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/ObservationTestUtil.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/ObservationTestUtil.java index 01dee1026..b0e4cdafa 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/ObservationTestUtil.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/ObservationTestUtil.java @@ -1,27 +1,28 @@ /* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.autoconfigure.vectorstore.observation; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; + import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention; import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; - /** * @author Christian Tzolov * @since 1.0.0 diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfigurationTests.java index 0fdd547ab..29b8a3878 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.observation; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vectorstore.observation; import io.micrometer.tracing.otel.bridge.OtelCurrentTraceContext; import io.micrometer.tracing.otel.bridge.OtelTracer; import io.opentelemetry.api.OpenTelemetry; import org.junit.jupiter.api.Test; + import org.springframework.ai.vectorstore.observation.VectorStoreQueryResponseObservationFilter; import org.springframework.ai.vectorstore.observation.VectorStoreQueryResponseObservationHandler; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + /** * Unit tests for {@link VectorStoreObservationAutoConfiguration}. * @@ -38,21 +40,21 @@ class VectorStoreObservationAutoConfigurationTests { @Test void queryResponseFilterDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(VectorStoreQueryResponseObservationFilter.class); }); } @Test void queryResponseHandlerDefault() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context).doesNotHaveBean(VectorStoreQueryResponseObservationHandler.class); }); } @Test void queryResponseHandlerEnabled() { - contextRunner + this.contextRunner .withBean(OtelTracer.class, OpenTelemetry.noop().getTracer("test"), new OtelCurrentTraceContext(), null) .withPropertyValues("spring.ai.vectorstore.observations.include-query-response=true") .run(context -> { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/AwsOpenSearchVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/AwsOpenSearchVectorStoreAutoConfigurationIT.java index b0f6e0ab1..7b23132d7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/AwsOpenSearchVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/AwsOpenSearchVectorStoreAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.opensearch; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.List; +import java.util.Map; + import com.jayway.jsonpath.JsonPath; import net.minidev.json.JSONArray; import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.localstack.LocalStackContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -31,16 +43,6 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.containers.localstack.LocalStackContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.DockerImageName; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; import static org.awaitility.Awaitility.await; @@ -56,11 +58,6 @@ class AwsOpenSearchVectorStoreAutoConfigurationIT { private static final String DOCUMENT_INDEX = "auto-spring-ai-document-index"; - private List documents = List.of( - new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), - new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), - new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(OpenSearchVectorStoreAutoConfiguration.class, SpringAiRetryAutoConfiguration.class)) @@ -86,6 +83,11 @@ class AwsOpenSearchVectorStoreAutoConfigurationIT { } """); + private List documents = List.of( + new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), + new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), + new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + @BeforeAll static void beforeAll() throws IOException, InterruptedException { String[] createDomainCmd = { "awslocal", "opensearch", "create-domain", "--domain-name", @@ -109,7 +111,7 @@ class AwsOpenSearchVectorStoreAutoConfigurationIT { this.contextRunner.run(context -> { OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await() .until(() -> vectorStore @@ -121,14 +123,14 @@ class AwsOpenSearchVectorStoreAutoConfigurationIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java index e7c584621..5445f6de7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,20 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.opensearch; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.opensearch; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.awaitility.Awaitility; import org.junit.jupiter.api.Test; import org.opensearch.testcontainers.OpensearchContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; +import software.amazon.awssdk.http.apache.ApacheHttpClient; +import software.amazon.awssdk.regions.Region; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -41,13 +45,10 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.DockerImageName; -import io.micrometer.observation.tck.TestObservationRegistry; -import software.amazon.awssdk.http.apache.ApacheHttpClient; -import software.amazon.awssdk.regions.Region; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; @Testcontainers class OpenSearchVectorStoreAutoConfigurationIT { @@ -58,11 +59,6 @@ class OpenSearchVectorStoreAutoConfigurationIT { private static final String DOCUMENT_INDEX = "auto-spring-ai-document-index"; - private List documents = List.of( - new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), - new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), - new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(OpenSearchVectorStoreAutoConfiguration.class, SpringAiRetryAutoConfiguration.class)) @@ -82,6 +78,11 @@ class OpenSearchVectorStoreAutoConfigurationIT { } """); + private List documents = List.of( + new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), + new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), + new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + @Test public void addAndSearchTest() { @@ -89,7 +90,7 @@ class OpenSearchVectorStoreAutoConfigurationIT { OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.OPENSEARCH, VectorStoreObservationContext.Operation.ADD); @@ -111,14 +112,14 @@ class OpenSearchVectorStoreAutoConfigurationIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.OPENSEARCH, VectorStoreObservationContext.Operation.DELETE); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfigurationIT.java index 6f9952ccc..078c381d2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.oracle; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.oracle; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.oracle.OracleContainer; +import org.testcontainers.utility.MountableFile; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.VectorStoreProvider; @@ -38,12 +42,9 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.oracle.OracleContainer; -import org.testcontainers.utility.MountableFile; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Christian Tzolov @@ -58,11 +59,6 @@ public class OracleVectorStoreAutoConfigurationIT { .withCopyFileToContainer(MountableFile.forClasspathResource("/oracle/initialize.sql"), "/container-entrypoint-initdb.d/initialize.sql"); - List documents = List.of( - new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), - new Document(getText("classpath:/test/data/time.shelter.txt")), - new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(OracleVectorStoreAutoConfiguration.class, JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) @@ -76,14 +72,29 @@ public class OracleVectorStoreAutoConfigurationIT { String.format("spring.datasource.password=%s", oracle23aiContainer.getPassword()), "spring.datasource.type=oracle.jdbc.pool.OracleDataSource"); + List documents = List.of( + new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), + new Document(getText("classpath:/test/data/time.shelter.txt")), + new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + @Test public void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.ORACLE, VectorStoreObservationContext.Operation.ADD); @@ -94,7 +105,7 @@ public class OracleVectorStoreAutoConfigurationIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); assertObservationRegistry(observationRegistry, VectorStoreProvider.ORACLE, @@ -102,7 +113,7 @@ public class OracleVectorStoreAutoConfigurationIT { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.ORACLE, VectorStoreObservationContext.Operation.DELETE); @@ -113,16 +124,6 @@ public class OracleVectorStoreAutoConfigurationIT { }); } - public static String getText(String uri) { - var resource = new DefaultResourceLoader().getResource(uri); - try { - return resource.getContentAsString(StandardCharsets.UTF_8); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - @Configuration(proxyBeanMethods = false) static class Config { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStorePropertiesTests.java index d7e00a9ef..b57df3892 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStorePropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.oracle; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.jupiter.api.Test; + import org.springframework.ai.vectorstore.OracleVectorStore; import org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType; import org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreIndexType; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Christian Tzolov */ diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfigurationIT.java index 6777b04e1..25a969fce 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.pgvector; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.pgvector; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.VectorStoreProvider; @@ -42,11 +45,9 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.jdbc.core.JdbcTemplate; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Christian Tzolov @@ -61,6 +62,18 @@ public class PgVectorStoreAutoConfigurationIT { @SuppressWarnings("resource") static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>("pgvector/pgvector:pg16"); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(PgVectorStoreAutoConfiguration.class, + JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .withPropertyValues("spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE", + "spring.ai.vectorstore.pgvector.initialize-schema=true", + // JdbcTemplate configuration + String.format("spring.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), + postgresContainer.getMappedPort(5432), postgresContainer.getDatabaseName()), + "spring.datasource.username=" + postgresContainer.getUsername(), + "spring.datasource.password=" + postgresContainer.getPassword()); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -76,22 +89,17 @@ public class PgVectorStoreAutoConfigurationIT { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(PgVectorStoreAutoConfiguration.class, - JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) - .withUserConfiguration(Config.class) - .withPropertyValues("spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE", - "spring.ai.vectorstore.pgvector.initialize-schema=true", - // JdbcTemplate configuration - String.format("spring.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), - postgresContainer.getMappedPort(5432), postgresContainer.getDatabaseName()), - "spring.datasource.username=" + postgresContainer.getUsername(), - "spring.datasource.password=" + postgresContainer.getPassword()); + private static boolean isFullyQualifiedTableExists(ApplicationContext context, String schemaName, + String tableName) { + JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); + String sql = "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = ? AND table_name = ?)"; + return jdbcTemplate.queryForObject(sql, Boolean.class, schemaName, tableName); + } @Test public void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { PgVectorStore vectorStore = context.getBean(PgVectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); @@ -100,7 +108,7 @@ public class PgVectorStoreAutoConfigurationIT { PgVectorStore.DEFAULT_TABLE_NAME)) .isTrue(); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.PG_VECTOR, VectorStoreObservationContext.Operation.ADD); @@ -111,7 +119,7 @@ public class PgVectorStoreAutoConfigurationIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); assertObservationRegistry(observationRegistry, VectorStoreProvider.PG_VECTOR, @@ -119,7 +127,7 @@ public class PgVectorStoreAutoConfigurationIT { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.PG_VECTOR, VectorStoreObservationContext.Operation.DELETE); @@ -136,7 +144,7 @@ public class PgVectorStoreAutoConfigurationIT { String schemaName = schemaTableName.split(":")[0]; String tableName = schemaTableName.split(":")[1]; - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vectorstore.pgvector.schema-name=" + schemaName, "spring.ai.vectorstore.pgvector.table-name=" + tableName) .run(context -> { @@ -150,7 +158,7 @@ public class PgVectorStoreAutoConfigurationIT { String schemaName = schemaTableName.split(":")[0]; String tableName = schemaTableName.split(":")[1]; - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vectorstore.pgvector.schema-name=" + schemaName, "spring.ai.vectorstore.pgvector.table-name=" + tableName, "spring.ai.vectorstore.pgvector.initialize-schema=false") @@ -174,11 +182,4 @@ public class PgVectorStoreAutoConfigurationIT { } - private static boolean isFullyQualifiedTableExists(ApplicationContext context, String schemaName, - String tableName) { - JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); - String sql = "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = ? AND table_name = ?)"; - return jdbcTemplate.queryForObject(sql, Boolean.class, schemaName, tableName); - } - } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStorePropertiesTests.java index a4e4ddef3..f52f4ce78 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStorePropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.pgvector; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vectorstore.pgvector; import org.junit.jupiter.api.Test; @@ -23,6 +22,8 @@ import org.springframework.ai.vectorstore.PgVectorStore; import org.springframework.ai.vectorstore.PgVectorStore.PgDistanceType; import org.springframework.ai.vectorstore.PgVectorStore.PgIndexType; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Christian Tzolov */ diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfigurationIT.java index 461e02abb..e5cc9f97d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.pinecone; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.pinecone; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -26,10 +23,12 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import io.micrometer.observation.tck.TestObservationRegistry; import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.VectorStoreProvider; @@ -43,7 +42,9 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Christian Tzolov @@ -53,6 +54,16 @@ import io.micrometer.observation.tck.TestObservationRegistry; @EnabledIfEnvironmentVariable(named = "PINECONE_API_KEY", matches = ".+") public class PineconeVectorStoreAutoConfigurationIT { + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(PineconeVectorStoreAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .withPropertyValues("spring.ai.vectorstore.pinecone.apiKey=" + System.getenv("PINECONE_API_KEY"), + "spring.ai.vectorstore.pinecone.environment=gcp-starter", + "spring.ai.vectorstore.pinecone.projectId=814621f", + "spring.ai.vectorstore.pinecone.indexName=spring-ai-test-index", + "spring.ai.vectorstore.pinecone.contentFieldName=customContentField", + "spring.ai.vectorstore.pinecone.distanceMetadataFieldName=customDistanceField"); + List documents = List.of( new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), @@ -68,16 +79,6 @@ public class PineconeVectorStoreAutoConfigurationIT { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(PineconeVectorStoreAutoConfiguration.class)) - .withUserConfiguration(Config.class) - .withPropertyValues("spring.ai.vectorstore.pinecone.apiKey=" + System.getenv("PINECONE_API_KEY"), - "spring.ai.vectorstore.pinecone.environment=gcp-starter", - "spring.ai.vectorstore.pinecone.projectId=814621f", - "spring.ai.vectorstore.pinecone.indexName=spring-ai-test-index", - "spring.ai.vectorstore.pinecone.contentFieldName=customContentField", - "spring.ai.vectorstore.pinecone.distanceMetadataFieldName=customDistanceField"); - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -88,12 +89,12 @@ public class PineconeVectorStoreAutoConfigurationIT { @Test public void addAndSearchTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { PineconeVectorStore vectorStore = context.getBean(PineconeVectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.PINECONE, VectorStoreObservationContext.Operation.ADD); @@ -107,7 +108,7 @@ public class PineconeVectorStoreAutoConfigurationIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); @@ -118,7 +119,7 @@ public class PineconeVectorStoreAutoConfigurationIT { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.PINECONE, VectorStoreObservationContext.Operation.DELETE); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java index ce006a438..ce450bac0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.pinecone; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vectorstore.pinecone; import java.time.Duration; import org.junit.jupiter.api.Test; + import org.springframework.ai.vectorstore.PineconeVectorStore; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Christian Tzolov */ diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfigurationIT.java index 99525be0e..6a418ab18 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.qdrant; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.qdrant; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.qdrant.QdrantContainer; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.VectorStoreProvider; @@ -36,11 +39,9 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.qdrant.QdrantContainer; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Christian Tzolov @@ -55,11 +56,6 @@ public class QdrantVectorStoreAutoConfigurationIT { @Container static QdrantContainer qdrantContainer = new QdrantContainer("qdrant/qdrant:v1.9.2"); - List documents = List.of( - new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), - new Document(getText("classpath:/test/data/time.shelter.txt")), - new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(QdrantVectorStoreAutoConfiguration.class)) .withUserConfiguration(Config.class) @@ -67,14 +63,29 @@ public class QdrantVectorStoreAutoConfigurationIT { "spring.ai.vectorstore.qdrant.initialize-schema=true", "spring.ai.vectorstore.qdrant.host=" + qdrantContainer.getHost()); + List documents = List.of( + new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), + new Document(getText("classpath:/test/data/time.shelter.txt")), + new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + @Test public void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.QDRANT, VectorStoreObservationContext.Operation.ADD); @@ -85,7 +96,7 @@ public class QdrantVectorStoreAutoConfigurationIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); assertObservationRegistry(observationRegistry, VectorStoreProvider.QDRANT, @@ -93,7 +104,7 @@ public class QdrantVectorStoreAutoConfigurationIT { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); assertThat(results).hasSize(0); @@ -103,16 +114,6 @@ public class QdrantVectorStoreAutoConfigurationIT { }); } - public static String getText(String uri) { - var resource = new DefaultResourceLoader().getResource(uri); - try { - return resource.getContentAsString(StandardCharsets.UTF_8); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - @Configuration(proxyBeanMethods = false) static class Config { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreCloudAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreCloudAutoConfigurationIT.java index 34e194c67..2358a821e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreCloudAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreCloudAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.qdrant; import java.io.IOException; @@ -67,6 +68,15 @@ public class QdrantVectorStoreCloudAutoConfigurationIT { // NOTE: The GRPC port (usually 6334) is different from the HTTP port (usually 6333)! private static final int CLOUD_GRPC_PORT = 6334; + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(QdrantVectorStoreAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .withPropertyValues("spring.ai.vectorstore.qdrant.port=" + CLOUD_GRPC_PORT, + "spring.ai.vectorstore.qdrant.host=" + CLOUD_HOST, + "spring.ai.vectorstore.qdrant.api-key=" + CLOUD_API_KEY, + "spring.ai.vectorstore.qdrant.collection-name=" + COLLECTION_NAME, + "spring.ai.vectorstore.qdrant.initializeSchema=true", "spring.ai.vectorstore.qdrant.use-tls=true"); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -92,38 +102,6 @@ public class QdrantVectorStoreCloudAutoConfigurationIT { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(QdrantVectorStoreAutoConfiguration.class)) - .withUserConfiguration(Config.class) - .withPropertyValues("spring.ai.vectorstore.qdrant.port=" + CLOUD_GRPC_PORT, - "spring.ai.vectorstore.qdrant.host=" + CLOUD_HOST, - "spring.ai.vectorstore.qdrant.api-key=" + CLOUD_API_KEY, - "spring.ai.vectorstore.qdrant.collection-name=" + COLLECTION_NAME, - "spring.ai.vectorstore.qdrant.initializeSchema=true", "spring.ai.vectorstore.qdrant.use-tls=true"); - - @Test - public void addAndSearch() { - contextRunner.run(context -> { - - VectorStore vectorStore = context.getBean(VectorStore.class); - - vectorStore.add(documents); - - List results = vectorStore - .similaritySearch(SearchRequest.query("What is Great Depression?").withTopK(1)); - - assertThat(results).hasSize(1); - Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); - assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); - - // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); - results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); - assertThat(results).hasSize(0); - }); - } - public static String getText(String uri) { var resource = new DefaultResourceLoader().getResource(uri); try { @@ -134,6 +112,29 @@ public class QdrantVectorStoreCloudAutoConfigurationIT { } } + @Test + public void addAndSearch() { + this.contextRunner.run(context -> { + + VectorStore vectorStore = context.getBean(VectorStore.class); + + vectorStore.add(this.documents); + + List results = vectorStore + .similaritySearch(SearchRequest.query("What is Great Depression?").withTopK(1)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); + assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); + + // Remove all documents from the store + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); + results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); + assertThat(results).hasSize(0); + }); + } + @Configuration(proxyBeanMethods = false) static class Config { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStorePropertiesTests.java index 878f298ab..31c3e5282 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStorePropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.qdrant; import org.junit.jupiter.api.Test; + import org.springframework.ai.vectorstore.qdrant.QdrantVectorStore; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfigurationIT.java index 634447464..41295f128 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.redis; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.redis; import java.util.List; import java.util.Map; +import com.redis.testcontainers.RedisStackContainer; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -35,12 +38,9 @@ import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import com.redis.testcontainers.RedisStackContainer; - -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Julien Ruaux @@ -56,11 +56,6 @@ class RedisVectorStoreAutoConfigurationIT { static RedisStackContainer redisContainer = new RedisStackContainer( RedisStackContainer.DEFAULT_IMAGE_NAME.withTag(RedisStackContainer.DEFAULT_TAG)); - List documents = List.of( - new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), - new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( - ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(RedisAutoConfiguration.class, RedisVectorStoreAutoConfiguration.class)) .withUserConfiguration(Config.class) @@ -69,13 +64,18 @@ class RedisVectorStoreAutoConfigurationIT { .withPropertyValues("spring.ai.vectorstore.redis.index=myIdx") .withPropertyValues("spring.ai.vectorstore.redis.prefix=doc:"); + List documents = List.of( + new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), + new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( + ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + @Test void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.REDIS, VectorStoreObservationContext.Operation.ADD); @@ -85,7 +85,7 @@ class RedisVectorStoreAutoConfigurationIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); @@ -94,7 +94,7 @@ class RedisVectorStoreAutoConfigurationIT { observationRegistry.clear(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.REDIS, VectorStoreObservationContext.Operation.DELETE); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStorePropertiesTests.java index cb3691031..0b38b40c7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStorePropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.redis; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.jupiter.api.Test; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Julien Ruaux * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfigurationIT.java index aea63b614..54aaaf8a1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vectorstore.typesense; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +package org.springframework.ai.autoconfigure.vectorstore.typesense; import java.time.Duration; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -35,11 +38,9 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import io.micrometer.observation.tck.TestObservationRegistry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Pablo Sanchidrian Herrera @@ -57,18 +58,18 @@ public class TypesenseVectorStoreAutoConfigurationIT { .withCommand("--data-dir", "/tmp", "--api-key=xyz", "--enable-cors") .withStartupTimeout(Duration.ofSeconds(100)); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(TypesenseVectorStoreAutoConfiguration.class)) + .withUserConfiguration(Config.class); + List documents = List.of( new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(TypesenseVectorStoreAutoConfiguration.class)) - .withUserConfiguration(Config.class); - @Test public void addAndSearch() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vectorstore.typesense.embeddingDimension=384", "spring.ai.vectorstore.typesense.collectionName=myTestCollection", "spring.ai.vectorstore.typesense.initialize-schema=true", @@ -80,7 +81,7 @@ public class TypesenseVectorStoreAutoConfigurationIT { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); assertObservationRegistry(observationRegistry, VectorStoreProvider.TYPESENSE, VectorStoreObservationContext.Operation.ADD); @@ -90,7 +91,7 @@ public class TypesenseVectorStoreAutoConfigurationIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); @@ -100,7 +101,7 @@ public class TypesenseVectorStoreAutoConfigurationIT { VectorStoreObservationContext.Operation.QUERY); observationRegistry.clear(); - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); assertObservationRegistry(observationRegistry, VectorStoreProvider.TYPESENSE, VectorStoreObservationContext.Operation.DELETE); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfigurationIT.java index ba81c7479..02d3b2a3e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfigurationIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.weaviate; import java.util.List; import java.util.Map; +import io.micrometer.observation.tck.TestObservationRegistry; import org.junit.jupiter.api.Test; import org.testcontainers.containers.wait.strategy.Wait; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.weaviate.WeaviateContainer; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -35,9 +38,6 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.testcontainers.weaviate.WeaviateContainer; - -import io.micrometer.observation.tck.TestObservationRegistry; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; @@ -68,7 +68,7 @@ public class WeaviateVectorStoreAutoConfigurationIT { @Test public void addAndSearchWithFilters() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { WeaviateVectorStoreProperties properties = context.getBean(WeaviateVectorStoreProperties.class); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingModelAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingModelAutoConfigurationIT.java index 3cb3481a2..d5397b56a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingModelAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingModelAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vertexai.embedding; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vertexai.embedding; import java.io.File; import java.util.List; @@ -23,6 +22,7 @@ import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.io.TempDir; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.DocumentEmbeddingRequest; import org.springframework.ai.embedding.EmbeddingOptions; @@ -33,6 +33,8 @@ import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingModel import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Christian Tzolov */ @@ -40,17 +42,17 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") public class VertexAiTextEmbeddingModelAutoConfigurationIT { - @TempDir - File tempDir; - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.vertex.ai.embedding.project-id=" + System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"), "spring.ai.vertex.ai.embedding.location=" + System.getenv("VERTEX_AI_GEMINI_LOCATION")) .withConfiguration(AutoConfigurations.of(VertexAiEmbeddingAutoConfiguration.class)); + @TempDir + File tempDir; + @Test public void textEmbedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { var conntectionProperties = context.getBean(VertexAiEmbeddingConnectionProperties.class); var textEmbeddingProperties = context.getBean(VertexAiTextEmbeddingProperties.class); @@ -69,17 +71,17 @@ public class VertexAiTextEmbeddingModelAutoConfigurationIT { @Test void textEmbeddingActivation() { - contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.text.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.text.enabled=false").run(context -> { assertThat(context.getBeansOfType(VertexAiTextEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiTextEmbeddingModel.class)).isEmpty(); }); - contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.text.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.text.enabled=true").run(context -> { assertThat(context.getBeansOfType(VertexAiTextEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiTextEmbeddingModel.class)).isNotEmpty(); }); - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(VertexAiTextEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiTextEmbeddingModel.class)).isNotEmpty(); }); @@ -88,7 +90,7 @@ public class VertexAiTextEmbeddingModelAutoConfigurationIT { @Test public void multimodalEmbedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { var conntectionProperties = context.getBean(VertexAiEmbeddingConnectionProperties.class); var multimodalEmbeddingProperties = context.getBean(VertexAiMultimodalEmbeddingProperties.class); @@ -122,17 +124,17 @@ public class VertexAiTextEmbeddingModelAutoConfigurationIT { @Test void multimodalEmbeddingActivation() { - contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.multimodal.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.multimodal.enabled=false").run(context -> { assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingModel.class)).isEmpty(); }); - contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.multimodal.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.multimodal.enabled=true").run(context -> { assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingModel.class)).isNotEmpty(); }); - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingModel.class)).isNotEmpty(); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfigurationIT.java index 7ab1ef34f..a9d9716a5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.gemini; import java.util.stream.Collectors; @@ -21,12 +22,12 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import reactor.core.publisher.Flux; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -45,7 +46,7 @@ public class VertexAiGeminiAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VertexAiGeminiChatModel chatModel = context.getBean(VertexAiGeminiChatModel.class); String response = chatModel.call("Hello"); assertThat(response).isNotEmpty(); @@ -55,7 +56,7 @@ public class VertexAiGeminiAutoConfigurationIT { @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VertexAiGeminiChatModel chatModel = context.getBean(VertexAiGeminiChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); String response = responseFlux.collectList().block().stream().map(chatResponse -> { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java index 66a8c0113..4d17b12cf 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vertexai.gemini.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vertexai.gemini.tool; import java.util.List; import java.util.function.Function; @@ -24,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.vertexai.gemini.VertexAiGeminiAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -37,6 +37,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; +import static org.assertj.core.api.Assertions.assertThat; + @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") class FunctionCallWithFunctionBeanIT { @@ -53,7 +55,7 @@ class FunctionCallWithFunctionBeanIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" // + VertexAiGeminiChatModel.ChatModel.GEMINI_PRO_1_5_PRO.getValue()) + VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH.getValue()) .run(context -> { @@ -69,21 +71,21 @@ class FunctionCallWithFunctionBeanIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().withFunction("weatherFunction3").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); response = chatModel .call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).doesNotContain("30", "10", "15"); @@ -93,7 +95,7 @@ class FunctionCallWithFunctionBeanIT { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner.withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" // + VertexAiGeminiChatModel.ChatModel.GEMINI_PRO_1_5_PRO.getValue()) + VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH.getValue()) .run(context -> { @@ -109,14 +111,14 @@ class FunctionCallWithFunctionBeanIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), PortableFunctionCallingOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().withFunction("weatherFunction3").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -142,4 +144,4 @@ class FunctionCallWithFunctionBeanIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java index 08a095119..34688fcef 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vertexai.gemini.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vertexai.gemini.tool; import java.util.List; @@ -23,6 +22,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.vertexai.gemini.VertexAiGeminiAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -37,6 +37,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import static org.assertj.core.api.Assertions.assertThat; + @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") public class FunctionCallWithFunctionWrapperIT { @@ -51,7 +53,7 @@ public class FunctionCallWithFunctionWrapperIT { @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" + VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH.getValue()) .run(context -> { @@ -66,7 +68,7 @@ public class FunctionCallWithFunctionWrapperIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().withFunction("WeatherInfo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -87,4 +89,4 @@ public class FunctionCallWithFunctionWrapperIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java index e72ac4480..2cde310ab 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.autoconfigure.vertexai.gemini.tool; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.autoconfigure.vertexai.gemini.tool; import java.util.List; @@ -23,6 +22,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.autoconfigure.vertexai.gemini.VertexAiGeminiAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -34,6 +34,8 @@ import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import static org.assertj.core.api.Assertions.assertThat; + @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") public class FunctionCallWithPromptFunctionIT { @@ -47,7 +49,7 @@ public class FunctionCallWithPromptFunctionIT { @Test void functionCallTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" + VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH.getValue()) .run(context -> { @@ -75,7 +77,7 @@ public class FunctionCallWithPromptFunctionIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -83,11 +85,11 @@ public class FunctionCallWithPromptFunctionIT { response = chatModel .call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).doesNotContain("30", "10", "15"); }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/MockWeatherService.java index ed34d7b0c..aa78f7594 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vertexai.gemini.tool; import java.util.function.Function; @@ -31,14 +32,21 @@ import com.fasterxml.jackson.annotation.JsonPropertyDescription; @JsonClassDescription("Get the weather in location") public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -66,28 +74,23 @@ public class MockWeatherService implements Function { + this.contextRunner.run(context -> { VertexAiPaLm2ChatModel chatModel = context.getBean(VertexAiPaLm2ChatModel.class); String response = chatModel.call("Hello"); @@ -56,7 +58,7 @@ public class VertexAiPaLm2AutoConfigurationIT { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VertexAiPaLm2EmbeddingModel embeddingModel = context.getBean(VertexAiPaLm2EmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingModel @@ -75,19 +77,19 @@ public class VertexAiPaLm2AutoConfigurationIT { public void embeddingActivation() { // Disable the embedding auto-configuration. - contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.enabled=false").run(context -> { assertThat(context.getBeansOfType(VertexAiPalm2EmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiPaLm2EmbeddingModel.class)).isEmpty(); }); // The embedding auto-configuration is enabled by default. - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(VertexAiPalm2EmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiPaLm2EmbeddingModel.class)).isNotEmpty(); }); // Explicitly enable the embedding auto-configuration. - contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.enabled=true").run(context -> { assertThat(context.getBeansOfType(VertexAiPalm2EmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiPaLm2EmbeddingModel.class)).isNotEmpty(); }); @@ -97,19 +99,19 @@ public class VertexAiPaLm2AutoConfigurationIT { public void chatActivation() { // Disable the chat auto-configuration. - contextRunner.withPropertyValues("spring.ai.vertex.ai.chat.enabled=false").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.chat.enabled=false").run(context -> { assertThat(context.getBeansOfType(VertexAiPlam2ChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiPaLm2ChatModel.class)).isEmpty(); }); // The chat auto-configuration is enabled by default. - contextRunner.run(context -> { + this.contextRunner.run(context -> { assertThat(context.getBeansOfType(VertexAiPlam2ChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiPaLm2ChatModel.class)).isNotEmpty(); }); // Explicitly enable the chat auto-configuration. - contextRunner.withPropertyValues("spring.ai.vertex.ai.chat.enabled=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.chat.enabled=true").run(context -> { assertThat(context.getBeansOfType(VertexAiPlam2ChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(VertexAiPaLm2ChatModel.class)).isNotEmpty(); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java index 1637b204e..049fd61c7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.watsonxai; import org.junit.Test; + import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -51,4 +53,4 @@ public class WatsonxAiAutoConfigurationTests { }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java index b9ec4e45f..f15f82f82 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; +import java.util.List; +import java.util.stream.Collectors; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -32,10 +38,6 @@ import org.springframework.ai.zhipuai.ZhiPuAiImageModel; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -54,7 +56,7 @@ public class ZhiPuAiAutoConfigurationIT { @Test void generate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); String response = chatModel.call("Hello"); assertThat(response).isNotEmpty(); @@ -64,7 +66,7 @@ public class ZhiPuAiAutoConfigurationIT { @Test void generateStreaming() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); String response = responseFlux.collectList().block().stream().map(chatResponse -> { @@ -78,7 +80,7 @@ public class ZhiPuAiAutoConfigurationIT { @Test void embedding() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { ZhiPuAiEmbeddingModel embeddingModel = context.getBean(ZhiPuAiEmbeddingModel.class); EmbeddingResponse embeddingResponse = embeddingModel @@ -95,7 +97,7 @@ public class ZhiPuAiAutoConfigurationIT { @Test void generateImage() { - contextRunner.withPropertyValues("spring.ai.zhipuai.image.options.size=1024x1024").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.image.options.size=1024x1024").run(context -> { ZhiPuAiImageModel ImageModel = context.getBean(ZhiPuAiImageModel.class); ImageResponse imageResponse = ImageModel.call(new ImagePrompt("forest")); assertThat(imageResponse.getResults()).hasSize(1); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java index 2aaf62946..2cc2c9b21 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai; import org.junit.jupiter.api.Test; import org.skyscreamer.jsonassert.JSONAssert; import org.skyscreamer.jsonassert.JSONCompareMode; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.zhipuai.ZhiPuAiChatModel; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java index 8dc63f205..ca91b63c3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai.tool; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -32,10 +38,6 @@ import org.springframework.ai.zhipuai.ZhiPuAiChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -54,7 +56,7 @@ public class FunctionCallbackInPromptIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); @@ -71,7 +73,7 @@ public class FunctionCallbackInPromptIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -80,7 +82,7 @@ public class FunctionCallbackInPromptIT { @Test void streamingFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); @@ -105,7 +107,7 @@ public class FunctionCallbackInPromptIT { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -113,4 +115,4 @@ public class FunctionCallbackInPromptIT { }); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index 7b440c381..5b2657c69 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai.tool; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -36,11 +43,6 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Description; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.function.Function; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -60,7 +62,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); @@ -71,7 +73,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -79,7 +81,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { response = chatModel.call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -88,7 +90,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); @@ -102,13 +104,13 @@ class FunctionCallbackWithPlainFunctionBeanIT { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); }); } @Test void streamFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); @@ -127,7 +129,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -145,7 +147,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -172,4 +174,4 @@ class FunctionCallbackWithPlainFunctionBeanIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java index 2596f3b84..9016104f2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai.tool; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; import org.springframework.ai.chat.messages.AssistantMessage; @@ -35,10 +41,6 @@ import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfigura import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import reactor.core.publisher.Flux; - -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -58,7 +60,7 @@ public class FunctionCallbackWrapperIT { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); @@ -68,7 +70,7 @@ public class FunctionCallbackWrapperIT { ChatResponse response = chatModel.call( new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withFunction("WeatherInfo").build())); - logger.info("Response: {}", response); + this.logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -77,7 +79,7 @@ public class FunctionCallbackWrapperIT { @Test void streamFunctionCallTest() { - contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); @@ -95,7 +97,7 @@ public class FunctionCallbackWrapperIT { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - logger.info("Response: {}", content); + this.logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -119,4 +121,4 @@ public class FunctionCallbackWrapperIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java index 61f6d6c2d..75d562648 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.zhipuai.tool; +import java.util.function.Function; + import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import java.util.function.Function; - /** * Mock 3rd party weather service. * @@ -30,16 +31,21 @@ import java.util.function.Function; */ public class MockWeatherService implements Function { - /** - * Weather Function request. - */ - @JsonInclude(Include.NON_NULL) - @JsonClassDescription("Weather API request") - public record Request(@JsonProperty(required = true, - value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, - @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); } /** @@ -67,28 +73,25 @@ public class MockWeatherService implements Function + + 4.0.0 diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java index b861bea6a..7d3f650aa 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.chroma; import org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaConnectionDetails; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironment.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironment.java index ddfba20fb..371ed7c9a 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironment.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironment.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.chroma; import java.util.Map; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactory.java index 8de8fb4ec..7c04ff256 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.mongo; import com.mongodb.ConnectionString; + import org.springframework.boot.autoconfigure.mongo.MongoConnectionDetails; import org.springframework.boot.docker.compose.core.RunningService; import org.springframework.boot.docker.compose.service.connection.DockerComposeConnectionDetailsFactory; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactory.java index db84ed581..6608cfa0a 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.ollama; import org.springframework.ai.autoconfigure.ollama.OllamaConnectionDetails; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactory.java index 718112524..0fbfe2088 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.opensearch; +import java.util.List; + import org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchConnectionDetails; import org.springframework.boot.docker.compose.core.RunningService; import org.springframework.boot.docker.compose.service.connection.DockerComposeConnectionDetailsFactory; import org.springframework.boot.docker.compose.service.connection.DockerComposeConnectionSource; -import java.util.List; - /** * @author Eddú Meléndez */ diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironment.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironment.java index 56adc0afe..034ba169f 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironment.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironment.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.opensearch; import java.util.Map; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactory.java index c73781dac..2de9b0a4b 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.qdrant; import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantConnectionDetails; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantEnvironment.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantEnvironment.java index 8752ad38f..8005a0ed2 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantEnvironment.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantEnvironment.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.docker.compose.service.connection.qdrant; import java.util.Map; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactory.java index 9de92e613..ad7691716 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.typesense; import org.springframework.ai.autoconfigure.vectorstore.typesense.TypesenseConnectionDetails; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironment.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironment.java index 139815a5d..b8b70e449 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironment.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironment.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.docker.compose.service.connection.typesense; import java.util.Map; diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactory.java index 6c88ff2c1..2f8216b7d 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.weaviate; import org.springframework.ai.autoconfigure.vectorstore.weaviate.WeaviateConnectionDetails; diff --git a/spring-ai-spring-boot-docker-compose/src/main/resources/META-INF/spring.factories b/spring-ai-spring-boot-docker-compose/src/main/resources/META-INF/spring.factories index cf9041575..fcc2bfdc3 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/resources/META-INF/spring.factories +++ b/spring-ai-spring-boot-docker-compose/src/main/resources/META-INF/spring.factories @@ -1,3 +1,19 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.springframework.boot.autoconfigure.service.connection.ConnectionDetailsFactory=\ org.springframework.ai.docker.compose.service.connection.chroma.ChromaDockerComposeConnectionDetailsFactory,\ org.springframework.ai.docker.compose.service.connection.mongo.MongoDbAtlasLocalDockerComposeConnectionDetailsFactory,\ diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactoryTests.java index 5c015d862..66e670c46 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactoryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.chroma; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironmentTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironmentTests.java index df3e514d0..37d416d13 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironmentTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaEnvironmentTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.docker.compose.service.connection.chroma; -import org.junit.jupiter.api.Test; +package org.springframework.ai.docker.compose.service.connection.chroma; import java.util.Collections; import java.util.Map; +import org.junit.jupiter.api.Test; + import static org.assertj.core.api.Assertions.assertThat; class ChromaEnvironmentTests { diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaWithTokenDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaWithTokenDockerComposeConnectionDetailsFactoryTests.java index 493f8ca54..8795dc4f6 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaWithTokenDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaWithTokenDockerComposeConnectionDetailsFactoryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.chroma; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactoryTests.java index b312d71b4..c88bf10cd 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactoryTests.java @@ -1,9 +1,26 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.docker.compose.service.connection.mongo; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.boot.autoconfigure.mongo.MongoConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactoryTests.java index 7f88d6a35..9b7282819 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/ollama/OllamaDockerComposeConnectionDetailsFactoryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.ollama; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.ollama.OllamaConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactoryTests.java index e50a8655d..a162d05d2 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactoryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.opensearch; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironmentTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironmentTests.java index 7e7ba6c42..c14572327 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironmentTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironmentTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.docker.compose.service.connection.opensearch; -import org.junit.jupiter.api.Test; +package org.springframework.ai.docker.compose.service.connection.opensearch; import java.util.Collections; import java.util.Map; +import org.junit.jupiter.api.Test; + import static org.assertj.core.api.Assertions.assertThat; class OpenSearchEnvironmentTests { diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactoryTests.java index bc907baaa..7dd990bc0 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/qdrant/QdrantDockerComposeConnectionDetailsFactoryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.qdrant; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactoryTests.java index a0c3925fe..b766c31cb 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactoryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.typesense; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.vectorstore.typesense.TypesenseConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironmentTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironmentTests.java index 9b65f53c4..9ed0920db 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironmentTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseEnvironmentTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.docker.compose.service.connection.typesense; -import org.junit.jupiter.api.Test; +package org.springframework.ai.docker.compose.service.connection.typesense; import java.util.Collections; import java.util.Map; +import org.junit.jupiter.api.Test; + import static org.assertj.core.api.Assertions.assertThat; class TypesenseEnvironmentTests { diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactoryTests.java index d046d8fcb..58f6a5a49 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/weaviate/WeaviateDockerComposeConnectionDetailsFactoryTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.docker.compose.service.connection.weaviate; import org.junit.jupiter.api.Test; +import org.testcontainers.utility.DockerImageName; + import org.springframework.ai.autoconfigure.vectorstore.weaviate.WeaviateConnectionDetails; import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests; -import org.testcontainers.utility.DockerImageName; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/docker/compose/service/connection/test/AbstractDockerComposeIntegrationTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/docker/compose/service/connection/test/AbstractDockerComposeIntegrationTests.java index b1279bd76..8c6289c87 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/docker/compose/service/connection/test/AbstractDockerComposeIntegrationTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/docker/compose/service/connection/test/AbstractDockerComposeIntegrationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,14 +26,14 @@ import java.util.Map; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.io.TempDir; -import org.springframework.boot.autoconfigure.ImportAutoConfiguration; -import org.springframework.boot.autoconfigure.web.servlet.ServletWebServerFactoryAutoConfiguration; -import org.springframework.boot.testsupport.DisabledIfProcessUnavailable; import org.testcontainers.utility.DockerImageName; import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplicationShutdownHandlers; +import org.springframework.boot.autoconfigure.ImportAutoConfiguration; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; +import org.springframework.boot.autoconfigure.web.servlet.ServletWebServerFactoryAutoConfiguration; +import org.springframework.boot.testsupport.DisabledIfProcessUnavailable; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; @@ -60,17 +60,17 @@ public abstract class AbstractDockerComposeIntegrationTests { private final DockerImageName dockerImageName; + protected AbstractDockerComposeIntegrationTests(String composeResource, DockerImageName dockerImageName) { + this.composeResource = new ClassPathResource(composeResource, getClass()); + this.dockerImageName = dockerImageName; + } + @AfterAll static void shutDown() { SpringApplicationShutdownHandlers shutdownHandlers = SpringApplication.getShutdownHandlers(); ((Runnable) shutdownHandlers).run(); } - protected AbstractDockerComposeIntegrationTests(String composeResource, DockerImageName dockerImageName) { - this.composeResource = new ClassPathResource(composeResource, getClass()); - this.dockerImageName = dockerImageName; - } - protected final T run(Class type) { SpringApplication application = new SpringApplication(Config.class); Map properties = new LinkedHashMap<>(); diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailable.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailable.java index aded7c820..857f2d6d9 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailable.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailable.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,6 @@ package org.springframework.boot.testsupport; -import org.junit.jupiter.api.extension.ExtendWith; - import java.lang.annotation.Documented; import java.lang.annotation.ElementType; import java.lang.annotation.Repeatable; @@ -25,6 +23,8 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.junit.jupiter.api.extension.ExtendWith; + /** * Disables test execution if a process is unavailable. * diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailableCondition.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailableCondition.java index 98690a1cc..62dd2f20e 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailableCondition.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailableCondition.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,21 +16,22 @@ package org.springframework.boot.testsupport; -import org.junit.jupiter.api.extension.ConditionEvaluationResult; -import org.junit.jupiter.api.extension.ExecutionCondition; -import org.junit.jupiter.api.extension.ExtensionContext; -import org.springframework.core.annotation.MergedAnnotation; -import org.springframework.core.annotation.MergedAnnotations; -import org.springframework.core.annotation.MergedAnnotations.SearchStrategy; -import org.springframework.util.Assert; -import org.springframework.util.StringUtils; - import java.lang.reflect.AnnotatedElement; import java.util.ArrayList; import java.util.List; import java.util.concurrent.TimeUnit; import java.util.stream.Stream; +import org.junit.jupiter.api.extension.ConditionEvaluationResult; +import org.junit.jupiter.api.extension.ExecutionCondition; +import org.junit.jupiter.api.extension.ExtensionContext; + +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.core.annotation.MergedAnnotations; +import org.springframework.core.annotation.MergedAnnotations.SearchStrategy; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + /** * An {@link ExecutionCondition} that disables execution if specified processes cannot * start. diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailables.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailables.java index b62bf3177..bfc2e88f2 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailables.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailables.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2023 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,14 +16,14 @@ package org.springframework.boot.testsupport; -import org.junit.jupiter.api.extension.ExtendWith; - import java.lang.annotation.Documented; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.junit.jupiter.api.extension.ExtendWith; + /** * Repeatable container for {@link DisabledIfProcessUnavailable}. * diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-anthropic/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-anthropic/pom.xml index 91584cb45..31e711ba2 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-anthropic/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-anthropic/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-aws-opensearch-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-aws-opensearch-store/pom.xml index dc5d637b9..186f0bad8 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-aws-opensearch-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-aws-opensearch-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-azure-cosmos-db-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-azure-cosmos-db-store/pom.xml index ac8c2ffb1..e2c3ec1cc 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-azure-cosmos-db-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-azure-cosmos-db-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-azure-openai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-azure-openai/pom.xml index 61a3a6d9d..7ccc8ef96 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-azure-openai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-azure-openai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-azure-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-azure-store/pom.xml index d9292532e..f5face502 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-azure-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-azure-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-bedrock-ai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-bedrock-ai/pom.xml index 781dc92a6..fc95593d5 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-bedrock-ai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-bedrock-ai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-cassandra-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-cassandra-store/pom.xml index 00b58ba75..968e4fa53 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-cassandra-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-cassandra-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-chroma-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-chroma-store/pom.xml index 6dd309936..f9e588f8b 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-chroma-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-chroma-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-elasticsearch-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-elasticsearch-store/pom.xml index 363ca6d11..efdba43b9 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-elasticsearch-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-elasticsearch-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-gemfire-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-gemfire-store/pom.xml index de48a933f..b393648d6 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-gemfire-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-gemfire-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-hanadb-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-hanadb-store/pom.xml index bba6eddce..0309e3970 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-hanadb-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-hanadb-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-huggingface/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-huggingface/pom.xml index 18d9e2c09..0832116aa 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-huggingface/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-huggingface/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-milvus-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-milvus-store/pom.xml index 825e9f030..49c66a4c2 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-milvus-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-milvus-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-minimax/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-minimax/pom.xml index 3004b42a9..124570853 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-minimax/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-minimax/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-mistral-ai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-mistral-ai/pom.xml index 05a617914..4c00f8326 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-mistral-ai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-mistral-ai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-mongodb-atlas-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-mongodb-atlas-store/pom.xml index 52cb6d058..7f868e68a 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-mongodb-atlas-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-mongodb-atlas-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-moonshot/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-moonshot/pom.xml index 2c2b19a84..70c34367b 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-moonshot/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-moonshot/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-neo4j-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-neo4j-store/pom.xml index cc2941114..f0203e697 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-neo4j-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-neo4j-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-oci-genai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-oci-genai/pom.xml index 2fd347de3..f8a835287 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-oci-genai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-oci-genai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-ollama/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-ollama/pom.xml index fc3558403..7b16c0c67 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-ollama/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-ollama/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-openai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-openai/pom.xml index 95b60e642..a5bce9888 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-openai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-openai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store/pom.xml index c97eb81ad..07533074d 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-oracle-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-oracle-store/pom.xml index 210cb8301..72d2cde2f 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-oracle-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-oracle-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-pgvector-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-pgvector-store/pom.xml index a194100e6..141e9f2a1 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-pgvector-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-pgvector-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-pinecone-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-pinecone-store/pom.xml index aefe6a0a6..bb1977f6d 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-pinecone-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-pinecone-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-postgresml-embedding/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-postgresml-embedding/pom.xml index 0378542fa..ffbc4aa55 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-postgresml-embedding/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-postgresml-embedding/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-qdrant-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-qdrant-store/pom.xml index acbfe9a28..fddb080fe 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-qdrant-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-qdrant-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-qianfan/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-qianfan/pom.xml index e8a312467..0da3b1776 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-qianfan/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-qianfan/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-redis-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-redis-store/pom.xml index 09fe4abb2..637b3e442 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-redis-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-redis-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-stability-ai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-stability-ai/pom.xml index b2a74b166..64764552e 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-stability-ai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-stability-ai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-transformers/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-transformers/pom.xml index 5cab1733d..9b0e61643 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-transformers/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-transformers/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-typesense-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-typesense-store/pom.xml index de5170aa8..2b3e52518 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-typesense-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-typesense-store/pom.xml @@ -1,4 +1,20 @@ + + diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-embedding/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-embedding/pom.xml index f59c35330..36fcee4d1 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-embedding/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-embedding/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-gemini/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-gemini/pom.xml index 3ef367efa..778b9cccd 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-gemini/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-gemini/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-palm2/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-palm2/pom.xml index eec7a61c0..97d0f1077 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-palm2/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-palm2/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai/pom.xml index 44dc3af3f..36e871422 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-weaviate-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-weaviate-store/pom.xml index ea237a2e1..e46246176 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-weaviate-store/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-weaviate-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-zhipuai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-zhipuai/pom.xml index 060ac2908..9eb4283c1 100644 --- a/spring-ai-spring-boot-starters/spring-ai-starter-zhipuai/pom.xml +++ b/spring-ai-spring-boot-starters/spring-ai-starter-zhipuai/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-testcontainers/pom.xml b/spring-ai-spring-boot-testcontainers/pom.xml index a56a7af7a..d1d5d1485 100644 --- a/spring-ai-spring-boot-testcontainers/pom.xml +++ b/spring-ai-spring-boot-testcontainers/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactory.java index 5efb6b1af..909d43f7a 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.chroma; +import java.util.Map; + +import org.testcontainers.chromadb.ChromaDBContainer; + import org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import org.testcontainers.chromadb.ChromaDBContainer; - -import java.util.Map; /** * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactory.java index a44643c38..137f85296 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.milvus; +import org.testcontainers.milvus.MilvusContainer; + import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusServiceClientConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import org.testcontainers.milvus.MilvusContainer; /** * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactory.java index 8bc4b2021..bf425a33c 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.mongo; import com.mongodb.ConnectionString; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; + import org.springframework.boot.autoconfigure.mongo.MongoConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; /** * A {@link ContainerConnectionDetailsFactory} implementation that provides diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactory.java index 46174bc36..b800be8b7 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.ollama; +import org.testcontainers.ollama.OllamaContainer; + import org.springframework.ai.autoconfigure.ollama.OllamaConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import org.testcontainers.ollama.OllamaContainer; /** * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactory.java index 22898785f..a154ddaf0 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.opensearch; +import java.util.List; + import org.opensearch.testcontainers.OpensearchContainer; + import org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import java.util.List; - /** * @author Eddú Meléndez */ diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactory.java index 619793d1b..e6a6bd536 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.qdrant; +import org.testcontainers.qdrant.QdrantContainer; + import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import org.testcontainers.qdrant.QdrantContainer; /** * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactory.java index 5779b3e53..68769925a 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.typesense; +import org.testcontainers.containers.Container; + import org.springframework.ai.autoconfigure.vectorstore.typesense.TypesenseConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import org.testcontainers.containers.Container; /** * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactory.java b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactory.java index 601fb6244..d953dfa9c 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-testcontainers/src/main/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.weaviate; +import org.testcontainers.weaviate.WeaviateContainer; + import org.springframework.ai.autoconfigure.vectorstore.weaviate.WeaviateConnectionDetails; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory; import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource; -import org.testcontainers.weaviate.WeaviateContainer; /** * @author Eddú Meléndez diff --git a/spring-ai-spring-boot-testcontainers/src/main/resources/META-INF/spring.factories b/spring-ai-spring-boot-testcontainers/src/main/resources/META-INF/spring.factories index bf8370e68..a4d88b8ef 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/resources/META-INF/spring.factories +++ b/spring-ai-spring-boot-testcontainers/src/main/resources/META-INF/spring.factories @@ -1,3 +1,19 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + org.springframework.boot.autoconfigure.service.connection.ConnectionDetailsFactory=\ org.springframework.ai.testcontainers.service.connection.chroma.ChromaContainerConnectionDetailsFactory,\ org.springframework.ai.testcontainers.service.connection.milvus.MilvusContainerConnectionDetailsFactory,\ diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactoryTest.java index 2bf14e3de..2c387f2c1 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactoryTest.java @@ -15,8 +15,15 @@ */ package org.springframework.ai.testcontainers.service.connection.chroma; +import java.util.List; +import java.util.Map; + import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -30,12 +37,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java index 2eef54949..efd038e5d 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.chroma; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithToken2ContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithToken2ContainerConnectionDetailsFactoryTest.java index 1c7daf642..86c9e2718 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithToken2ContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithToken2ContainerConnectionDetailsFactoryTest.java @@ -15,8 +15,15 @@ */ package org.springframework.ai.testcontainers.service.connection.chroma; +import java.util.List; +import java.util.Map; + import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -30,12 +37,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithTokenContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithTokenContainerConnectionDetailsFactoryTest.java index 34460d6fb..fbe935032 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithTokenContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithTokenContainerConnectionDetailsFactoryTest.java @@ -15,8 +15,15 @@ */ package org.springframework.ai.testcontainers.service.connection.chroma; +import java.util.List; +import java.util.Map; + import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -30,12 +37,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactoryTest.java index c4aed4e34..2f3f51910 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusContainerConnectionDetailsFactoryTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.milvus; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.milvus.MilvusContainer; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; @@ -30,12 +38,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.milvus.MilvusContainer; - -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -61,22 +63,22 @@ class MilvusContainerConnectionDetailsFactoryTest { @Test public void addAndSearch() { - vectorStore.add(documents); + this.vectorStore.add(this.documents); - List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); + List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()) .contains("Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKeys("spring", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + this.vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); - results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); + results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(0); } @@ -91,4 +93,4 @@ class MilvusContainerConnectionDetailsFactoryTest { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusImage.java index e125d08ba..168e854e7 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.milvus; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIt.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIt.java index 34b748de4..1f05b5d0a 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIt.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIt.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.mongo; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; + import org.springframework.ai.autoconfigure.vectorstore.mongo.MongoDBAtlasVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -32,13 +41,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; - -import java.util.Collections; -import java.util.List; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -68,10 +70,10 @@ class MongoDbAtlasLocalContainerConnectionDetailsFactoryIT { "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression", Collections.singletonMap("meta2", "meta2"))); - vectorStore.add(documents); + this.vectorStore.add(documents); Thread.sleep(5000); // Await a second for the document to be indexed - List results = vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); + List results = this.vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -81,9 +83,9 @@ class MongoDbAtlasLocalContainerConnectionDetailsFactoryIT { assertThat(resultDoc.getMetadata()).containsEntry("meta2", "meta2"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).collect(Collectors.toList())); + this.vectorStore.delete(documents.stream().map(Document::getId).collect(Collectors.toList())); - List results2 = vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); + List results2 = this.vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); assertThat(results2).isEmpty(); } @@ -99,4 +101,4 @@ class MongoDbAtlasLocalContainerConnectionDetailsFactoryIT { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbImage.java index 3ac2aa284..af0cb68f5 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.mongo; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactoryTest.java index 6e139b747..8cb9f4a12 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaContainerConnectionDetailsFactoryTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.ollama; +import java.io.IOException; +import java.util.List; + import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.ollama.OllamaContainer; + import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.ollama.OllamaEmbeddingModel; @@ -30,12 +38,6 @@ import org.springframework.boot.testcontainers.service.connection.ServiceConnect import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.ollama.OllamaContainer; - -import java.io.IOException; -import java.util.List; import static org.assertj.core.api.Assertions.assertThat; @@ -50,10 +52,10 @@ import static org.assertj.core.api.Assertions.assertThat; + OllamaContainerConnectionDetailsFactoryTest.MODEL_NAME) class OllamaContainerConnectionDetailsFactoryTest { - private static final Logger logger = LoggerFactory.getLogger(OllamaContainerConnectionDetailsFactoryTest.class); - static final String MODEL_NAME = "nomic-embed-text"; + private static final Logger logger = LoggerFactory.getLogger(OllamaContainerConnectionDetailsFactoryTest.class); + @Container @ServiceConnection static OllamaContainer ollama = new OllamaContainer(OllamaImage.DEFAULT_IMAGE); @@ -82,4 +84,4 @@ class OllamaContainerConnectionDetailsFactoryTest { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java index 59cf49403..c1bce0c70 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.ollama; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactoryTest.java index 43760317e..7e4393beb 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchContainerConnectionDetailsFactoryTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.testcontainers.service.connection.opensearch; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; +package org.springframework.ai.testcontainers.service.connection.opensearch; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -26,6 +24,9 @@ import java.util.Map; import org.awaitility.Awaitility; import org.junit.jupiter.api.Test; import org.opensearch.testcontainers.OpensearchContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -39,8 +40,9 @@ import org.springframework.boot.testcontainers.service.connection.ServiceConnect import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; @SpringBootTest(properties = { "spring.ai.vectorstore.opensearch.index-name=" + OpenSearchContainerConnectionDetailsFactoryTest.DOCUMENT_INDEX, @@ -50,14 +52,14 @@ import org.testcontainers.junit.jupiter.Testcontainers; @Testcontainers class OpenSearchContainerConnectionDetailsFactoryTest { - @Container - @ServiceConnection - private static final OpensearchContainer opensearch = new OpensearchContainer<>(OpenSearchImage.DEFAULT_IMAGE); - static final String DOCUMENT_INDEX = "auto-spring-ai-document-index"; static final String MAPPING_JSON = "{\"properties\":{\"embedding\":{\"type\":\"knn_vector\",\"dimension\":384}}}"; + @Container + @ServiceConnection + private static final OpensearchContainer opensearch = new OpensearchContainer<>(OpenSearchImage.DEFAULT_IMAGE); + private final List documents = List.of( new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), @@ -69,29 +71,29 @@ class OpenSearchContainerConnectionDetailsFactoryTest { @Test public void addAndSearchTest() { - vectorStore.add(documents); + this.vectorStore.add(this.documents); Awaitility.await() - .until(() -> vectorStore + .until(() -> this.vectorStore .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)), hasSize(1)); - List results = vectorStore + List results = this.vectorStore .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + this.vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() - .until(() -> vectorStore + .until(() -> this.vectorStore .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)), hasSize(0)); } diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchImage.java index 8c636cbba..26a615e6a 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.opensearch; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactoryTest.java index bdc5a7b10..c25773c79 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerConnectionDetailsFactoryTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.qdrant; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.qdrant.QdrantContainer; + import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -30,14 +40,6 @@ import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.qdrant.QdrantContainer; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -59,24 +61,6 @@ public class QdrantContainerConnectionDetailsFactoryTest { @Autowired private VectorStore vectorStore; - @Test - public void addAndSearch() { - vectorStore.add(documents); - - List results = vectorStore - .similaritySearch(SearchRequest.query("What is Great Depression?").withTopK(1)); - - assertThat(results).hasSize(1); - Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); - assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); - - // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); - results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); - assertThat(results).hasSize(0); - } - public static String getText(String uri) { var resource = new DefaultResourceLoader().getResource(uri); try { @@ -87,6 +71,24 @@ public class QdrantContainerConnectionDetailsFactoryTest { } } + @Test + public void addAndSearch() { + this.vectorStore.add(this.documents); + + List results = this.vectorStore + .similaritySearch(SearchRequest.query("What is Great Depression?").withTopK(1)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); + assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); + + // Remove all documents from the store + this.vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); + results = this.vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); + assertThat(results).hasSize(0); + } + @Configuration(proxyBeanMethods = false) @ImportAutoConfiguration(QdrantVectorStoreAutoConfiguration.class) static class Config { diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerWithApiKeyConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerWithApiKeyConnectionDetailsFactoryTest.java index 4642ec663..998e220df 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerWithApiKeyConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantContainerWithApiKeyConnectionDetailsFactoryTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.qdrant; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.qdrant.QdrantContainer; + import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -30,14 +40,6 @@ import org.springframework.context.annotation.Configuration; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.qdrant.QdrantContainer; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -59,24 +61,6 @@ public class QdrantContainerWithApiKeyConnectionDetailsFactoryTest { @Autowired private VectorStore vectorStore; - @Test - public void addAndSearch() { - vectorStore.add(documents); - - List results = vectorStore - .similaritySearch(SearchRequest.query("What is Great Depression?").withTopK(1)); - - assertThat(results).hasSize(1); - Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); - assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); - - // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); - results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); - assertThat(results).hasSize(0); - } - public static String getText(String uri) { var resource = new DefaultResourceLoader().getResource(uri); try { @@ -87,6 +71,24 @@ public class QdrantContainerWithApiKeyConnectionDetailsFactoryTest { } } + @Test + public void addAndSearch() { + this.vectorStore.add(this.documents); + + List results = this.vectorStore + .similaritySearch(SearchRequest.query("What is Great Depression?").withTopK(1)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); + assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); + + // Remove all documents from the store + this.vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); + results = this.vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); + assertThat(results).hasSize(0); + } + @Configuration(proxyBeanMethods = false) @ImportAutoConfiguration(QdrantVectorStoreAutoConfiguration.class) static class Config { diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantImage.java index a50b6576a..618e61cfc 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.qdrant; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactoryTest.java index 338b076a7..1aae0793a 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseContainerConnectionDetailsFactoryTest.java @@ -1,6 +1,30 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.testcontainers.service.connection.typesense; +import java.time.Duration; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ResourceUtils; import org.springframework.ai.autoconfigure.vectorstore.typesense.TypesenseVectorStoreAutoConfiguration; import org.springframework.ai.document.Document; @@ -15,13 +39,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.time.Duration; -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -51,19 +68,19 @@ class TypesenseContainerConnectionDetailsFactoryTest { @Test public void addAndSearch() { - this.vectorStore.add(documents); + this.vectorStore.add(this.documents); List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()) .contains("Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKeys("spring", "distance"); - this.vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + this.vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(0); @@ -80,4 +97,4 @@ class TypesenseContainerConnectionDetailsFactoryTest { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseImage.java index 406596506..bf9982363 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.typesense; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactoryTest.java index 3a35f27d3..b30588e4e 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateContainerConnectionDetailsFactoryTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.weaviate; +import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.Test; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.weaviate.WeaviateContainer; + import org.springframework.ai.autoconfigure.vectorstore.weaviate.WeaviateVectorStoreAutoConfiguration; import org.springframework.ai.autoconfigure.vectorstore.weaviate.WeaviateVectorStoreProperties; import org.springframework.ai.document.Document; @@ -31,13 +40,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.test.context.TestPropertySource; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.testcontainers.containers.wait.strategy.Wait; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.weaviate.WeaviateContainer; - -import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -63,15 +65,15 @@ class WeaviateContainerConnectionDetailsFactoryTest { @Test public void addAndSearchWithFilters() { - assertThat(properties.getFilterField()).hasSize(4); + assertThat(this.properties.getFilterField()).hasSize(4); - assertThat(properties.getFilterField().get("country")) + assertThat(this.properties.getFilterField().get("country")) .isEqualTo(WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField.Type.TEXT); - assertThat(properties.getFilterField().get("year")) + assertThat(this.properties.getFilterField().get("year")) .isEqualTo(WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField.Type.NUMBER); - assertThat(properties.getFilterField().get("active")) + assertThat(this.properties.getFilterField().get("active")) .isEqualTo(WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField.Type.BOOLEAN); - assertThat(properties.getFilterField().get("price")) + assertThat(this.properties.getFilterField().get("price")) .isEqualTo(WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField.Type.NUMBER); var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", @@ -79,39 +81,39 @@ class WeaviateContainerConnectionDetailsFactoryTest { var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner", Map.of("country", "Netherlands", "price", 1.57, "active", false, "year", 2023)); - vectorStore.add(List.of(bgDocument, nlDocument)); + this.vectorStore.add(List.of(bgDocument, nlDocument)); var request = SearchRequest.query("The World").withTopK(5); - List results = vectorStore.similaritySearch(request); + List results = this.vectorStore.similaritySearch(request); assertThat(results).hasSize(2); - results = vectorStore + results = this.vectorStore .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("country == 'Bulgaria'")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); - results = vectorStore + results = this.vectorStore .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("country == 'Netherlands'")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); - results = vectorStore.similaritySearch( + results = this.vectorStore.similaritySearch( request.withSimilarityThresholdAll().withFilterExpression("price > 1.57 && active == true")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); - results = vectorStore + results = this.vectorStore .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("year in [2020, 2023]")); assertThat(results).hasSize(2); - results = vectorStore + results = this.vectorStore .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("year > 2020 && year <= 2023")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); // Remove all documents from the store - vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); + this.vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); } @Configuration(proxyBeanMethods = false) @@ -125,4 +127,4 @@ class WeaviateContainerConnectionDetailsFactoryTest { } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateImage.java index cdece3ffa..8157d3c2e 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.weaviate; import org.testcontainers.utility.DockerImageName; diff --git a/spring-ai-spring-cloud-bindings/pom.xml b/spring-ai-spring-cloud-bindings/pom.xml index 2bf03d03c..da3ffe77f 100644 --- a/spring-ai-spring-cloud-bindings/pom.xml +++ b/spring-ai-spring-cloud-bindings/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/BindingsValidator.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/BindingsValidator.java index c53fedb35..bc4dc5087 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/BindingsValidator.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/BindingsValidator.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessor.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessor.java index e4cd607b7..67e1da572 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessor.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessor.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,14 +16,14 @@ package org.springframework.ai.bindings; +import java.net.URI; +import java.util.Map; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.cloud.bindings.boot.BindingsPropertiesProcessor; import org.springframework.core.env.Environment; -import java.net.URI; -import java.util.Map; - /** * An implementation of {@link BindingsPropertiesProcessor} that detects {@link Binding}s * of type: {@value TYPE}. diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessor.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessor.java index 3a22a564f..07dabaa3d 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessor.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessor.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,13 +16,13 @@ package org.springframework.ai.bindings; +import java.util.Map; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.cloud.bindings.boot.BindingsPropertiesProcessor; import org.springframework.core.env.Environment; -import java.util.Map; - /** * An implementation of {@link BindingsPropertiesProcessor} that detects {@link Binding}s * of type: {@value TYPE}. diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessor.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessor.java index 857fde84b..8afe9e393 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessor.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessor.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,13 +16,13 @@ package org.springframework.ai.bindings; +import java.util.Map; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.cloud.bindings.boot.BindingsPropertiesProcessor; import org.springframework.core.env.Environment; -import java.util.Map; - /** * An implementation of {@link BindingsPropertiesProcessor} that detects {@link Binding}s * of type: {@value TYPE}. diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessor.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessor.java index af98292e6..d00e04961 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessor.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessor.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,13 +16,13 @@ package org.springframework.ai.bindings; +import java.util.Map; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.cloud.bindings.boot.BindingsPropertiesProcessor; import org.springframework.core.env.Environment; -import java.util.Map; - /** * An implementation of {@link BindingsPropertiesProcessor} that detects {@link Binding}s * of type: {@value TYPE}. diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessor.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessor.java index 8832af47b..bb14eadd6 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessor.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessor.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,14 +16,14 @@ package org.springframework.ai.bindings; +import java.util.Arrays; +import java.util.Map; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.cloud.bindings.boot.BindingsPropertiesProcessor; import org.springframework.core.env.Environment; -import java.util.Arrays; -import java.util.Map; - /** * An implementation of {@link BindingsPropertiesProcessor} that detects {@link Binding}s * of type: {@value TYPE}. diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessor.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessor.java index 1b2160f4d..223210a1b 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessor.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessor.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,14 +16,14 @@ package org.springframework.ai.bindings; +import java.net.URI; +import java.util.Map; + import org.springframework.cloud.bindings.Binding; import org.springframework.cloud.bindings.Bindings; import org.springframework.cloud.bindings.boot.BindingsPropertiesProcessor; import org.springframework.core.env.Environment; -import java.net.URI; -import java.util.Map; - /** * An implementation of {@link BindingsPropertiesProcessor} that detects {@link Binding}s * of type: {@value TYPE}. diff --git a/spring-ai-spring-cloud-bindings/src/main/resources/META-INF/spring.factories b/spring-ai-spring-cloud-bindings/src/main/resources/META-INF/spring.factories index 9562cc660..668f29e0e 100644 --- a/spring-ai-spring-cloud-bindings/src/main/resources/META-INF/spring.factories +++ b/spring-ai-spring-cloud-bindings/src/main/resources/META-INF/spring.factories @@ -1,3 +1,19 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Binding Properties Factories org.springframework.cloud.bindings.boot.BindingsPropertiesProcessor=\ org.springframework.ai.bindings.ChromaBindingsPropertiesProcessor,\ diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessorTests.java index d9023c336..885d8f259 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,16 @@ package org.springframework.ai.bindings; -import org.junit.jupiter.api.Test; -import org.springframework.cloud.bindings.Binding; -import org.springframework.cloud.bindings.Bindings; -import org.springframework.mock.env.MockEnvironment; - import java.nio.file.Paths; import java.util.HashMap; import java.util.Map; +import org.junit.jupiter.api.Test; + +import org.springframework.cloud.bindings.Binding; +import org.springframework.cloud.bindings.Bindings; +import org.springframework.mock.env.MockEnvironment; + import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; @@ -51,19 +52,19 @@ class ChromaBindingsPropertiesProcessorTests { @Test void propertiesAreContributed() { - new ChromaBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).containsEntry("spring.ai.vectorstore.chroma.client.host", "https://example.net"); - assertThat(properties).containsEntry("spring.ai.vectorstore.chroma.client.port", "8000"); - assertThat(properties).containsEntry("spring.ai.vectorstore.chroma.client.username", "itsme"); - assertThat(properties).containsEntry("spring.ai.vectorstore.chroma.client.password", "youknowit"); + new ChromaBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).containsEntry("spring.ai.vectorstore.chroma.client.host", "https://example.net"); + assertThat(this.properties).containsEntry("spring.ai.vectorstore.chroma.client.port", "8000"); + assertThat(this.properties).containsEntry("spring.ai.vectorstore.chroma.client.username", "itsme"); + assertThat(this.properties).containsEntry("spring.ai.vectorstore.chroma.client.password", "youknowit"); } @Test void whenDisabledThenPropertiesAreNotContributed() { - environment.setProperty("%s.chroma.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty("%s.chroma.enabled".formatted(CONFIG_PATH), "false"); - new ChromaBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).isEmpty(); + new ChromaBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).isEmpty(); } } diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessorTests.java index 05d175f3d..0c2d1db35 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,16 @@ package org.springframework.ai.bindings; -import org.junit.jupiter.api.Test; -import org.springframework.cloud.bindings.Binding; -import org.springframework.cloud.bindings.Bindings; -import org.springframework.mock.env.MockEnvironment; - import java.nio.file.Paths; import java.util.HashMap; import java.util.Map; +import org.junit.jupiter.api.Test; + +import org.springframework.cloud.bindings.Binding; +import org.springframework.cloud.bindings.Bindings; +import org.springframework.mock.env.MockEnvironment; + import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; @@ -50,17 +51,17 @@ class MistralAiBindingsPropertiesProcessorTests { @Test void propertiesAreContributed() { - new MistralAiBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).containsEntry("spring.ai.mistralai.api-key", "demo"); - assertThat(properties).containsEntry("spring.ai.mistralai.base-url", "https://my.mistralai.example.net"); + new MistralAiBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).containsEntry("spring.ai.mistralai.api-key", "demo"); + assertThat(this.properties).containsEntry("spring.ai.mistralai.base-url", "https://my.mistralai.example.net"); } @Test void whenDisabledThenPropertiesAreNotContributed() { - environment.setProperty("%s.mistralai.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty("%s.mistralai.enabled".formatted(CONFIG_PATH), "false"); - new MistralAiBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).isEmpty(); + new MistralAiBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).isEmpty(); } } diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessorTests.java index 247c95210..b308fae9a 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,16 @@ package org.springframework.ai.bindings; -import org.junit.jupiter.api.Test; -import org.springframework.cloud.bindings.Binding; -import org.springframework.cloud.bindings.Bindings; -import org.springframework.mock.env.MockEnvironment; - import java.nio.file.Paths; import java.util.HashMap; import java.util.Map; +import org.junit.jupiter.api.Test; + +import org.springframework.cloud.bindings.Binding; +import org.springframework.cloud.bindings.Bindings; +import org.springframework.mock.env.MockEnvironment; + import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; @@ -49,16 +50,16 @@ class OllamaBindingsPropertiesProcessorTests { @Test void propertiesAreContributed() { - new OllamaBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).containsEntry("spring.ai.ollama.base-url", "https://example.net/ollama:11434"); + new OllamaBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).containsEntry("spring.ai.ollama.base-url", "https://example.net/ollama:11434"); } @Test void whenDisabledThenPropertiesAreNotContributed() { - environment.setProperty("%s.ollama.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty("%s.ollama.enabled".formatted(CONFIG_PATH), "false"); - new OllamaBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).isEmpty(); + new OllamaBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).isEmpty(); } } diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessorTests.java index 08225ab67..fdebd11a2 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,16 @@ package org.springframework.ai.bindings; -import org.junit.jupiter.api.Test; -import org.springframework.cloud.bindings.Binding; -import org.springframework.cloud.bindings.Bindings; -import org.springframework.mock.env.MockEnvironment; - import java.nio.file.Paths; import java.util.HashMap; import java.util.Map; +import org.junit.jupiter.api.Test; + +import org.springframework.cloud.bindings.Binding; +import org.springframework.cloud.bindings.Bindings; +import org.springframework.mock.env.MockEnvironment; + import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; @@ -50,17 +51,17 @@ class OpenAiBindingsPropertiesProcessorTests { @Test void propertiesAreContributed() { - new OpenAiBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).containsEntry("spring.ai.openai.api-key", "demo"); - assertThat(properties).containsEntry("spring.ai.openai.base-url", "https://my.openai.example.net"); + new OpenAiBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).containsEntry("spring.ai.openai.api-key", "demo"); + assertThat(this.properties).containsEntry("spring.ai.openai.base-url", "https://my.openai.example.net"); } @Test void whenDisabledThenPropertiesAreNotContributed() { - environment.setProperty("%s.openai.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty("%s.openai.enabled".formatted(CONFIG_PATH), "false"); - new OpenAiBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).isEmpty(); + new OpenAiBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).isEmpty(); } } diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessorTests.java index 14f33b9c0..40492754f 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,16 @@ package org.springframework.ai.bindings; -import org.junit.jupiter.api.Test; -import org.springframework.cloud.bindings.Binding; -import org.springframework.cloud.bindings.Bindings; -import org.springframework.mock.env.MockEnvironment; - import java.nio.file.Paths; import java.util.HashMap; import java.util.Map; +import org.junit.jupiter.api.Test; + +import org.springframework.cloud.bindings.Binding; +import org.springframework.cloud.bindings.Bindings; +import org.springframework.mock.env.MockEnvironment; + import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; @@ -69,27 +70,29 @@ class TanzuBindingsPropertiesProcessorTests { @Test void propertiesAreContributed() { - new TanzuBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).containsEntry("spring.ai.openai.chat.api-key", "demo"); - assertThat(properties).containsEntry("spring.ai.openai.chat.base-url", "https://my.openai.example.net"); - assertThat(properties).containsEntry("spring.ai.openai.chat.options.model", "llava1.6"); - assertThat(properties).containsEntry("spring.ai.openai.embedding.api-key", "demo2"); - assertThat(properties).containsEntry("spring.ai.openai.embedding.base-url", "https://my.openai2.example.net"); - assertThat(properties).containsEntry("spring.ai.openai.embedding.options.model", "text-embed-large"); + new TanzuBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).containsEntry("spring.ai.openai.chat.api-key", "demo"); + assertThat(this.properties).containsEntry("spring.ai.openai.chat.base-url", "https://my.openai.example.net"); + assertThat(this.properties).containsEntry("spring.ai.openai.chat.options.model", "llava1.6"); + assertThat(this.properties).containsEntry("spring.ai.openai.embedding.api-key", "demo2"); + assertThat(this.properties).containsEntry("spring.ai.openai.embedding.base-url", + "https://my.openai2.example.net"); + assertThat(this.properties).containsEntry("spring.ai.openai.embedding.options.model", "text-embed-large"); } @Test void propertiesAreMissingModelCapabilities() { - new TanzuBindingsPropertiesProcessor().process(environment, bindingsMissingModelCapabilities, properties); - assertThat(properties).isEmpty(); + new TanzuBindingsPropertiesProcessor().process(this.environment, this.bindingsMissingModelCapabilities, + this.properties); + assertThat(this.properties).isEmpty(); } @Test void whenDisabledThenPropertiesAreNotContributed() { - environment.setProperty("%s.genai.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty("%s.genai.enabled".formatted(CONFIG_PATH), "false"); - new TanzuBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).isEmpty(); + new TanzuBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).isEmpty(); } } diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessorTests.java index 9d91a8e84..f48638ff2 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessorTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,15 +16,16 @@ package org.springframework.ai.bindings; -import org.junit.jupiter.api.Test; -import org.springframework.cloud.bindings.Binding; -import org.springframework.cloud.bindings.Bindings; -import org.springframework.mock.env.MockEnvironment; - import java.nio.file.Paths; import java.util.HashMap; import java.util.Map; +import org.junit.jupiter.api.Test; + +import org.springframework.cloud.bindings.Binding; +import org.springframework.cloud.bindings.Bindings; +import org.springframework.mock.env.MockEnvironment; + import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; @@ -50,18 +51,18 @@ class WeaviateBindingsPropertiesProcessorTests { @Test void propertiesAreContributed() { - new WeaviateBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).containsEntry("spring.ai.vectorstore.weaviate.scheme", "https"); - assertThat(properties).containsEntry("spring.ai.vectorstore.weaviate.host", "example.net:8000"); - assertThat(properties).containsEntry("spring.ai.vectorstore.weaviate.api-key", "demo"); + new WeaviateBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).containsEntry("spring.ai.vectorstore.weaviate.scheme", "https"); + assertThat(this.properties).containsEntry("spring.ai.vectorstore.weaviate.host", "example.net:8000"); + assertThat(this.properties).containsEntry("spring.ai.vectorstore.weaviate.api-key", "demo"); } @Test void whenDisabledThenPropertiesAreNotContributed() { - environment.setProperty("%s.weaviate.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty("%s.weaviate.enabled".formatted(CONFIG_PATH), "false"); - new WeaviateBindingsPropertiesProcessor().process(environment, bindings, properties); - assertThat(properties).isEmpty(); + new WeaviateBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); + assertThat(this.properties).isEmpty(); } } diff --git a/spring-ai-test/pom.xml b/spring-ai-test/pom.xml index 45cd0df17..3397a92ae 100644 --- a/spring-ai-test/pom.xml +++ b/spring-ai-test/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java b/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java index 2f4b01a79..ea056cca3 100644 --- a/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java +++ b/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.evaluation; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; +package org.springframework.ai.evaluation; import java.util.List; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.model.ChatModel; @@ -32,6 +31,9 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + public class BasicEvaluationTest { private static final Logger logger = LoggerFactory.getLogger(BasicEvaluationTest.class); @@ -56,23 +58,23 @@ public class BasicEvaluationTest { assertThat(answer).isNotNull(); logger.info("Question: " + question); logger.info("Answer:" + answer); - PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource, + PromptTemplate userPromptTemplate = new PromptTemplate(this.userEvaluatorResource, Map.of("question", question, "answer", answer)); SystemMessage systemMessage; if (factBased) { - systemMessage = new SystemMessage(qaEvaluatorFactBasedAnswerResource); + systemMessage = new SystemMessage(this.qaEvaluatorFactBasedAnswerResource); } else { - systemMessage = new SystemMessage(qaEvaluatorAccurateAnswerResource); + systemMessage = new SystemMessage(this.qaEvaluatorAccurateAnswerResource); } Message userMessage = userPromptTemplate.createMessage(); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - String yesOrNo = openAiChatModel.call(prompt).getResult().getOutput().getContent(); + String yesOrNo = this.openAiChatModel.call(prompt).getResult().getOutput().getContent(); logger.info("Is Answer related to question: " + yesOrNo); if (yesOrNo.equalsIgnoreCase("no")) { - SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource); + SystemMessage notRelatedSystemMessage = new SystemMessage(this.qaEvaluatorNotRelatedResource); prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage)); - String reasonForFailure = openAiChatModel.call(prompt).getResult().getOutput().getContent(); + String reasonForFailure = this.openAiChatModel.call(prompt).getResult().getOutput().getContent(); fail(reasonForFailure); } else { @@ -81,4 +83,4 @@ public class BasicEvaluationTest { } } -} \ No newline at end of file +} diff --git a/src/checkstyle/checkstyle-header.txt b/src/checkstyle/checkstyle-header.txt new file mode 100644 index 000000000..9c6236680 --- /dev/null +++ b/src/checkstyle/checkstyle-header.txt @@ -0,0 +1,17 @@ +^\Q/*\E$ +^\Q * Copyright \E20\d\d\-20\d\d\Q the original author or authors.\E$ +^\Q *\E$ +^\Q * Licensed under the Apache License, Version 2.0 (the "License");\E$ +^\Q * you may not use this file except in compliance with the License.\E$ +^\Q * You may obtain a copy of the License at\E$ +^\Q *\E$ +^\Q * https://www.apache.org/licenses/LICENSE-2.0\E$ +^\Q *\E$ +^\Q * Unless required by applicable law or agreed to in writing, software\E$ +^\Q * distributed under the License is distributed on an "AS IS" BASIS,\E$ +^\Q * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\E$ +^\Q * See the License for the specific language governing permissions and\E$ +^\Q * limitations under the License.\E$ +^\Q */\E$ +^$ +^.*$ \ No newline at end of file diff --git a/src/checkstyle/checkstyle-suppressions.xml b/src/checkstyle/checkstyle-suppressions.xml new file mode 100644 index 000000000..5f78aa52b --- /dev/null +++ b/src/checkstyle/checkstyle-suppressions.xml @@ -0,0 +1,32 @@ + + + + + + + + + + + + + + + + diff --git a/src/checkstyle/checkstyle.xml b/src/checkstyle/checkstyle.xml new file mode 100644 index 000000000..0a224dabb --- /dev/null +++ b/src/checkstyle/checkstyle.xml @@ -0,0 +1,201 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml b/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml index 0410a18bd..8c1e57ccd 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml +++ b/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBFilterExpressionConverter.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBFilterExpressionConverter.java index cb5f1df3c..294dffd85 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBFilterExpressionConverter.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,9 +16,6 @@ package org.springframework.ai.vectorstore; -import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AND; -import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; - import java.util.Collection; import java.util.Map; import java.util.Optional; @@ -29,6 +26,9 @@ import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.ai.vectorstore.filter.Filter.Key; import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AND; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; + /** * Converts {@link org.springframework.ai.vectorstore.filter.Filter.Expression} into * Cosmos DB NoSQL API where clauses. @@ -51,7 +51,7 @@ class CosmosDBFilterExpressionConverter extends AbstractFilterExpressionConverte */ private Optional getMetadataField(String name) { String metadataField = name; - return Optional.ofNullable(metadataFields.get(metadataField)); + return Optional.ofNullable(this.metadataFields.get(metadataField)); } @Override diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java index 54716b142..8c4697017 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -24,20 +24,6 @@ import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.IntStream; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.BatchingStrategy; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; -import org.springframework.ai.embedding.TokenCountBatchingStrategy; -import org.springframework.ai.observation.conventions.VectorStoreProvider; -import org.springframework.ai.vectorstore.filter.Filter; -import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; - import com.azure.cosmos.CosmosAsyncClient; import com.azure.cosmos.CosmosAsyncContainer; import com.azure.cosmos.CosmosAsyncDatabase; @@ -66,10 +52,23 @@ import com.azure.cosmos.util.CosmosPagedFlux; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; - import io.micrometer.observation.ObservationRegistry; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; +import org.springframework.ai.observation.conventions.VectorStoreProvider; +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; + /** * @author Theo van Kraay * @author Soby Chacko @@ -81,14 +80,14 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen private final CosmosAsyncClient cosmosClient; - private CosmosAsyncContainer container; - private final EmbeddingModel embeddingModel; private final CosmosDBVectorStoreConfig properties; private final BatchingStrategy batchingStrategy; + private CosmosAsyncContainer container; + public CosmosDBVectorStore(ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, CosmosAsyncClient cosmosClient, CosmosDBVectorStoreConfig properties, EmbeddingModel embeddingModel) { @@ -210,7 +209,7 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen CosmosItemOperation operation = CosmosBulkOperations .getCreateItemOperation(mapCosmosDocument(doc, doc.getEmbedding()), new PartitionKey(doc.getId())); return new ImmutablePair<>(doc.getId(), operation); // Pair the document ID - // with the operation + // with the operation }).toList(); try { @@ -233,7 +232,7 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen String errorMessage = String.format("Duplicate document id: %s", documentId); logger.error(errorMessage); throw new RuntimeException(errorMessage); // Throw an exception - // for status code 409 + // for status code 409 } else { logger.info("Document added with status: {}", statusCode); @@ -307,10 +306,10 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen if (filterExpression != null) { CosmosDBFilterExpressionConverter filterExpressionConverter = new CosmosDBFilterExpressionConverter( this.properties.getMetadataFieldsList()); // Use the expression - // directly as - // it handles the - // "metadata" - // fields internally + // directly as + // it handles the + // "metadata" + // fields internally String filterQuery = filterExpressionConverter.convertExpression(filterExpression); queryBuilder.append(" AND ").append(filterQuery); } diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreConfig.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreConfig.java index 244ee8145..96729cbcf 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreConfig.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreConfig.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -51,15 +51,15 @@ public class CosmosDBVectorStoreConfig implements AutoCloseable { this.vectorStoreThroughput = vectorStoreThroughput; } + public String getMetadataFields() { + return this.metadataFields; + } + public void setMetadataFields(String metadataFields) { this.metadataFields = metadataFields; this.metadataFieldsList = List.of(metadataFields.split(",")); } - public String getMetadataFields() { - return this.metadataFields; - } - public List getMetadataFieldsList() { return this.metadataFieldsList; } diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java index 6c0644443..d8432fa71 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,11 +16,17 @@ package org.springframework.ai.vectorstore; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + import com.azure.cosmos.CosmosAsyncClient; import com.azure.cosmos.CosmosClientBuilder; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -30,10 +36,6 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.UUID; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -53,8 +55,8 @@ public class CosmosDBVectorStoreIT { @BeforeEach public void setup() { - contextRunner.run(context -> { - vectorStore = context.getBean(VectorStore.class); + this.contextRunner.run(context -> { + this.vectorStore = context.getBean(VectorStore.class); }); } @@ -66,25 +68,25 @@ public class CosmosDBVectorStoreIT { Document document2 = new Document(UUID.randomUUID().toString(), "Sample content2", Map.of("key2", "value2")); // Add the document to the vector store - vectorStore.add(List.of(document1, document2)); + this.vectorStore.add(List.of(document1, document2)); // create duplicate docs and assert that second one throws exception Document document3 = new Document(document1.getId(), "Sample content3", Map.of("key3", "value3")); - assertThatThrownBy(() -> vectorStore.add(List.of(document3))).isInstanceOf(Exception.class) + assertThatThrownBy(() -> this.vectorStore.add(List.of(document3))).isInstanceOf(Exception.class) .hasMessageContaining("Duplicate document id: " + document1.getId()); // Perform a similarity search - List results = vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); + List results = this.vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); // Verify the search results assertThat(results).isNotEmpty(); assertThat(results.get(0).getId()).isEqualTo(document1.getId()); // Remove the documents from the vector store - vectorStore.delete(List.of(document1.getId(), document2.getId())); + this.vectorStore.delete(List.of(document1.getId(), document2.getId())); // Perform a similarity search again - List results2 = vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); + List results2 = this.vectorStore.similaritySearch(SearchRequest.query("Sample content").withTopK(1)); // Verify the search results assertThat(results2).isEmpty(); @@ -124,16 +126,16 @@ public class CosmosDBVectorStoreIT { Document document3 = new Document("3", "A document about the US", metadata3); Document document4 = new Document("4", "A document about the US", metadata4); - vectorStore.add(List.of(document1, document2, document3, document4)); + this.vectorStore.add(List.of(document1, document2, document3, document4)); FilterExpressionBuilder b = new FilterExpressionBuilder(); - List results = vectorStore.similaritySearch(SearchRequest.query("The World") + List results = this.vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(10) .withFilterExpression((b.in("country", "UK", "NL")).build())); assertThat(results).hasSize(2); assertThat(results).extracting(Document::getId).containsExactlyInAnyOrder("1", "2"); - List results2 = vectorStore.similaritySearch(SearchRequest.query("The World") + List results2 = this.vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(10) .withFilterExpression( b.and(b.or(b.gte("year", 2021), b.eq("country", "NL")), b.ne("city", "Amsterdam")).build())); @@ -141,17 +143,17 @@ public class CosmosDBVectorStoreIT { assertThat(results2).hasSize(1); assertThat(results2).extracting(Document::getId).containsExactlyInAnyOrder("1"); - List results3 = vectorStore.similaritySearch(SearchRequest.query("The World") + List results3 = this.vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(10) .withFilterExpression(b.and(b.eq("country", "US"), b.eq("year", 2020)).build())); assertThat(results3).hasSize(1); assertThat(results3).extracting(Document::getId).containsExactlyInAnyOrder("4"); - vectorStore.delete(List.of(document1.getId(), document2.getId(), document3.getId(), document4.getId())); + this.vectorStore.delete(List.of(document1.getId(), document2.getId(), document3.getId(), document4.getId())); // Perform a similarity search again - List results4 = vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(1)); + List results4 = this.vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(1)); // Verify the search results assertThat(results4).isEmpty(); @@ -191,6 +193,7 @@ public class CosmosDBVectorStoreIT { public VectorStoreObservationConvention observationConvention() { // Replace with an actual observation convention or a mock if needed return new VectorStoreObservationConvention() { + }; } diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/resources/application.properties b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/resources/application.properties index 20c6c6220..82882acde 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/resources/application.properties +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/resources/application.properties @@ -1,3 +1,19 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + spring.ai.vectorstore.cosmosdb.databaseName=db spring.ai.vectorstore.cosmosdb.containerName=container spring.ai.vectorstore.cosmosdb.key=${COSMOSDB_AI_ENDPOINT} diff --git a/vector-stores/spring-ai-azure-store/pom.xml b/vector-stores/spring-ai-azure-store/pom.xml index fa819779c..25bf7e5f1 100644 --- a/vector-stores/spring-ai-azure-store/pom.xml +++ b/vector-stores/spring-ai-azure-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java index 2bcb1f3f0..ed127ea11 100644 --- a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java +++ b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.azure; import java.text.ParseException; diff --git a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java index 4a5ea3445..e5ed2c457 100644 --- a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java +++ b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -23,27 +23,6 @@ import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.BatchingStrategy; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; -import org.springframework.ai.embedding.TokenCountBatchingStrategy; -import org.springframework.ai.model.EmbeddingUtils; -import org.springframework.ai.observation.conventions.VectorStoreProvider; -import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; -import org.springframework.ai.vectorstore.SearchRequest; -import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; -import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext.Builder; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; -import org.springframework.beans.factory.InitializingBean; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; - import com.alibaba.fastjson2.JSONObject; import com.alibaba.fastjson2.TypeReference; import com.azure.core.util.Context; @@ -63,8 +42,28 @@ import com.azure.search.documents.models.IndexingResult; import com.azure.search.documents.models.SearchOptions; import com.azure.search.documents.models.VectorSearchOptions; import com.azure.search.documents.models.VectorizedQuery; - import io.micrometer.observation.ObservationRegistry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; +import org.springframework.ai.model.EmbeddingUtils; +import org.springframework.ai.observation.conventions.VectorStoreProvider; +import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; +import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext.Builder; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; /** * Uses Azure Cognitive Search as a backing vector store. Documents can be preloaded into @@ -81,14 +80,14 @@ import io.micrometer.observation.ObservationRegistry; */ public class AzureVectorStore extends AbstractObservationVectorStore implements InitializingBean { + public static final String DEFAULT_INDEX_NAME = "spring_ai_azure_vector_store"; + private static final Logger logger = LoggerFactory.getLogger(AzureVectorStore.class); private static final String SPRING_AI_VECTOR_CONFIG = "spring-ai-vector-config"; private static final String SPRING_AI_VECTOR_PROFILE = "spring-ai-vector-profile"; - public static final String DEFAULT_INDEX_NAME = "spring_ai_azure_vector_store"; - private static final String ID_FIELD_NAME = "id"; private static final String CONTENT_FIELD_NAME = "content"; @@ -109,16 +108,8 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements private final EmbeddingModel embeddingModel; - private SearchClient searchClient; - private final FilterExpressionConverter filterExpressionConverter; - private int defaultTopK = DEFAULT_TOP_K; - - private Double defaultSimilarityThreshold = DEFAULT_SIMILARITY_THRESHOLD; - - private String indexName = DEFAULT_INDEX_NAME; - private final boolean initializeSchema; private final BatchingStrategy batchingStrategy; @@ -134,32 +125,13 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements */ private final List filterMetadataFields; - public record MetadataField(String name, SearchFieldDataType fieldType) { + private SearchClient searchClient; - public static MetadataField text(String name) { - return new MetadataField(name, SearchFieldDataType.STRING); - } + private int defaultTopK = DEFAULT_TOP_K; - public static MetadataField int32(String name) { - return new MetadataField(name, SearchFieldDataType.INT32); - } + private Double defaultSimilarityThreshold = DEFAULT_SIMILARITY_THRESHOLD; - public static MetadataField int64(String name) { - return new MetadataField(name, SearchFieldDataType.INT64); - } - - public static MetadataField decimal(String name) { - return new MetadataField(name, SearchFieldDataType.DOUBLE); - } - - public static MetadataField bool(String name) { - return new MetadataField(name, SearchFieldDataType.BOOLEAN); - } - - public static MetadataField date(String name) { - return new MetadataField(name, SearchFieldDataType.DATE_TIME_OFFSET); - } - } + private String indexName = DEFAULT_INDEX_NAME; /** * Constructs a new AzureCognitiveSearchVectorStore. @@ -320,7 +292,7 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements Assert.notNull(request, "The search request must not be null."); - var searchEmbedding = embeddingModel.embed(request.getQuery()); + var searchEmbedding = this.embeddingModel.embed(request.getQuery()); final var vectorQuery = new VectorizedQuery(EmbeddingUtils.toList(searchEmbedding)) .setKNearestNeighborsCount(request.getTopK()) @@ -336,7 +308,7 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements searchOptions.setFilter(oDataFilter); } - final var searchResults = searchClient.search(null, searchOptions, Context.NONE); + final var searchResults = this.searchClient.search(null, searchOptions, Context.NONE); return searchResults.stream() .filter(result -> result.getScore() >= request.getSimilarityThreshold()) @@ -346,6 +318,7 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements Map metadata = (StringUtils.hasText(entry.metadata())) ? JSONObject.parseObject(entry.metadata(), new TypeReference>() { + }) : Map.of(); metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - (float) result.getScore()); @@ -359,12 +332,6 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements .collect(Collectors.toList()); } - /** - * Internal data structure for retrieving and storing documents. - */ - private record AzureSearchDocument(String id, String content, List embedding, String metadata) { - } - @Override public void afterPropertiesSet() throws Exception { @@ -426,4 +393,39 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements .withSimilarityMetric(this.initializeSchema ? VectorStoreSimilarityMetric.COSINE.value() : null); } + public record MetadataField(String name, SearchFieldDataType fieldType) { + + public static MetadataField text(String name) { + return new MetadataField(name, SearchFieldDataType.STRING); + } + + public static MetadataField int32(String name) { + return new MetadataField(name, SearchFieldDataType.INT32); + } + + public static MetadataField int64(String name) { + return new MetadataField(name, SearchFieldDataType.INT64); + } + + public static MetadataField decimal(String name) { + return new MetadataField(name, SearchFieldDataType.DOUBLE); + } + + public static MetadataField bool(String name) { + return new MetadataField(name, SearchFieldDataType.BOOLEAN); + } + + public static MetadataField date(String name) { + return new MetadataField(name, SearchFieldDataType.DATE_TIME_OFFSET); + } + + } + + /** + * Internal data structure for retrieving and storing documents. + */ + private record AzureSearchDocument(String id, String content, List embedding, String metadata) { + + } + } diff --git a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverterTests.java b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverterTests.java index 65d67fcc6..2bd638290 100644 --- a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.azure; import java.util.Date; @@ -204,4 +205,4 @@ public class AzureAiSearchFilterExpressionConverterTests { """); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java index 9b9808177..03418bb20 100644 --- a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java +++ b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.azure; import java.io.IOException; @@ -56,14 +57,14 @@ import static org.hamcrest.Matchers.hasSize; @EnabledIfEnvironmentVariable(named = "AZURE_AI_SEARCH_ENDPOINT", matches = ".+") public class AzureVectorStoreIT { + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class); + List documents = List.of( new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class); - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -71,14 +72,24 @@ public class AzureVectorStoreIT { Awaitility.setDefaultTimeout(Duration.ofMinutes(1)); } + private static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + @Test public void addAndSearchTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await() .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)), @@ -88,14 +99,14 @@ public class AzureVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); Awaitility.await() .until(() -> vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)), hasSize(0)); @@ -105,7 +116,7 @@ public class AzureVectorStoreIT { @Test public void searchWithFilters() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); var bgDocument = new Document("1", "The World is Big and Salvation Lurks Around the Corner", @@ -194,7 +205,7 @@ public class AzureVectorStoreIT { @Test public void documentUpdateTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -247,11 +258,11 @@ public class AzureVectorStoreIT { @Test public void searchThresholdTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await() .until(() -> vectorStore @@ -272,13 +283,13 @@ public class AzureVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); Awaitility.await() .until(() -> vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)), hasSize(0)); }); @@ -309,14 +320,4 @@ public class AzureVectorStoreIT { } - private static String getText(String uri) { - var resource = new DefaultResourceLoader().getResource(uri); - try { - return resource.getContentAsString(StandardCharsets.UTF_8); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreObservationIT.java b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreObservationIT.java index 10c876db0..6ce752d5f 100644 --- a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore.azure; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore.azure; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,10 +23,17 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.search.documents.indexes.SearchIndexClient; +import com.azure.search.documents.indexes.SearchIndexClientBuilder; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -47,13 +53,7 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import com.azure.core.credential.AzureKeyCredential; -import com.azure.search.documents.indexes.SearchIndexClient; -import com.azure.search.documents.indexes.SearchIndexClientBuilder; - -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * Integration tests for observation instrumentation AbstractObservationVectorStore in @@ -66,6 +66,9 @@ import io.micrometer.observation.tck.TestObservationRegistryAssert; @EnabledIfEnvironmentVariable(named = "AZURE_AI_SEARCH_ENDPOINT", matches = ".+") public class AzureVectorStoreObservationIT { + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -81,9 +84,6 @@ public class AzureVectorStoreObservationIT { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class); - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -94,13 +94,13 @@ public class AzureVectorStoreObservationIT { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-cassandra-store/pom.xml b/vector-stores/spring-ai-cassandra-store/pom.xml index c676627b9..1032f3635 100644 --- a/vector-stores/spring-ai-cassandra-store/pom.xml +++ b/vector-stores/spring-ai-cassandra-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/cassandra/SchemaUtil.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/cassandra/SchemaUtil.java index f945b9db8..11ca9efe0 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/cassandra/SchemaUtil.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/cassandra/SchemaUtil.java @@ -1,30 +1,29 @@ /* + * Copyright 2023-2024 the original author or authors. + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * - * See the NOTICE file distributed with this work for additional information - * regarding copyright ownership. */ + package org.springframework.ai.cassandra; +import java.time.Duration; + import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.cql.SimpleStatement; import com.datastax.oss.driver.api.querybuilder.SchemaBuilder; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.time.Duration; - /** * @author Mick Semb Wever * @since 1.0.0 diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemory.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemory.java index 7dc82f370..c12c81dd6 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemory.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemory.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.memory; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; + import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder; import com.datastax.oss.driver.api.core.cql.PreparedStatement; import com.datastax.oss.driver.api.core.cql.Row; @@ -31,18 +37,13 @@ import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; -import java.time.Instant; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.atomic.AtomicLong; - /** * Create a CassandraChatMemory like -CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build()); - + CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build()); + * * For example @see org.springframework.ai.chat.memory.CassandraChatMemory - * + * * @author Mick Semb Wever * @since 1.0.0 */ @@ -54,10 +55,6 @@ public final class CassandraChatMemory implements ChatMemory { private final PreparedStatement addUserStmt, addAssistantStmt, getStmt, deleteStmt; - public static CassandraChatMemory create(CassandraChatMemoryConfig conf) { - return new CassandraChatMemory(conf); - } - public CassandraChatMemory(CassandraChatMemoryConfig config) { this.conf = config; this.conf.ensureSchemaExists(); @@ -67,6 +64,10 @@ public final class CassandraChatMemory implements ChatMemory { this.deleteStmt = prepareDeleteStmt(); } + public static CassandraChatMemory create(CassandraChatMemoryConfig conf) { + return new CassandraChatMemory(conf); + } + @Override public void add(String conversationId, List messages) { final AtomicLong instantSeq = new AtomicLong(Instant.now().toEpochMilli()); @@ -90,8 +91,8 @@ public final class CassandraChatMemory implements ChatMemory { PreparedStatement stmt; switch (msg.getMessageType()) { - case USER -> stmt = addUserStmt; - case ASSISTANT -> stmt = addAssistantStmt; + case USER -> stmt = this.addUserStmt; + case ASSISTANT -> stmt = this.addAssistantStmt; default -> throw new IllegalArgumentException("Cant add type " + msg); } @@ -115,7 +116,7 @@ public final class CassandraChatMemory implements ChatMemory { public void clear(String sessionId) { List primaryKeys = this.conf.primaryKeyTranslator.apply(sessionId); - BoundStatementBuilder builder = deleteStmt.boundStatementBuilder(); + BoundStatementBuilder builder = this.deleteStmt.boundStatementBuilder(); for (int k = 0; k < primaryKeys.size(); ++k) { SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); @@ -129,7 +130,7 @@ public final class CassandraChatMemory implements ChatMemory { public List get(String sessionId, int lastN) { List primaryKeys = this.conf.primaryKeyTranslator.apply(sessionId); - BoundStatementBuilder builder = getStmt.boundStatementBuilder().setInt("lastN", lastN); + BoundStatementBuilder builder = this.getStmt.boundStatementBuilder().setInt("lastN", lastN); for (int k = 0; k < primaryKeys.size(); ++k) { SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemoryConfig.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemoryConfig.java index ab046e2ec..3c9f329b6 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemoryConfig.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemoryConfig.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.memory; +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.function.Function; + import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.CqlSessionBuilder; import com.datastax.oss.driver.api.core.cql.SimpleStatement; @@ -32,42 +40,17 @@ import com.datastax.oss.driver.api.querybuilder.schema.CreateTableStart; import com.datastax.oss.driver.api.querybuilder.schema.CreateTableWithOptions; import com.datastax.oss.driver.shaded.guava.common.annotations.VisibleForTesting; import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.cassandra.SchemaUtil; -import java.net.InetSocketAddress; -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.UUID; -import java.util.function.Function; - /** * @author Mick Semb Wever * @since 1.0.0 */ public final class CassandraChatMemoryConfig { - private static final Logger logger = LoggerFactory.getLogger(CassandraChatMemoryConfig.class); - - record Schema(String keyspace, String table, List partitionKeys, List clusteringKeys) { - } - - public record SchemaColumn(String name, DataType type) { - - public GenericType javaType() { - return CodecRegistry.DEFAULT.codecFor(type).getJavaType(); - } - } - - /** Given a string sessionId, return the value for each primary key column. */ - public interface SessionIdToPrimaryKeysTranslator extends Function> { - - } - public static final String DEFAULT_KEYSPACE_NAME = "springframework"; public static final String DEFAULT_TABLE_NAME = "ai_chat_memory"; @@ -82,6 +65,8 @@ public final class CassandraChatMemoryConfig { public static final String DEFAULT_USER_COLUMN_NAME = "user"; + private static final Logger logger = LoggerFactory.getLogger(CassandraChatMemoryConfig.class); + final CqlSession session; final Schema schema; @@ -90,16 +75,12 @@ public final class CassandraChatMemoryConfig { final String userColumn; + final SessionIdToPrimaryKeysTranslator primaryKeyTranslator; + private final Integer timeToLiveSeconds; private final boolean disallowSchemaChanges; - final SessionIdToPrimaryKeysTranslator primaryKeyTranslator; - - public static Builder builder() { - return new Builder(); - } - private CassandraChatMemoryConfig(Builder builder) { this.session = builder.session; this.schema = new Schema(builder.keyspace, builder.table, builder.partitionKeys, builder.clusteringKeys); @@ -110,6 +91,10 @@ public final class CassandraChatMemoryConfig { this.primaryKeyTranslator = builder.primaryKeyTranslator; } + public static Builder builder() { + return new Builder(); + } + SchemaColumn getPrimaryKeyColumn(int index) { return index < this.schema.partitionKeys().size() ? this.schema.partitionKeys().get(index) : this.schema.clusteringKeys().get(index - this.schema.partitionKeys().size()); @@ -121,6 +106,113 @@ public final class CassandraChatMemoryConfig { this.session.execute(SchemaBuilder.dropKeyspace(this.schema.keyspace).ifExists().build()); } + void ensureSchemaExists() { + if (!this.disallowSchemaChanges) { + SchemaUtil.ensureKeyspaceExists(this.session, this.schema.keyspace); + ensureTableExists(); + ensureTableColumnsExist(); + SchemaUtil.checkSchemaAgreement(this.session); + } + else { + checkSchemaValid(); + } + } + + void checkSchemaValid() { + + Preconditions.checkState(this.session.getMetadata().getKeyspace(this.schema.keyspace).isPresent(), + "keyspace %s does not exist", this.schema.keyspace); + + Preconditions.checkState(this.session.getMetadata() + .getKeyspace(this.schema.keyspace) + .get() + .getTable(this.schema.table) + .isPresent(), "table %s does not exist"); + + TableMetadata tableMetadata = this.session.getMetadata() + .getKeyspace(this.schema.keyspace) + .get() + .getTable(this.schema.table) + .get(); + + Preconditions.checkState(tableMetadata.getColumn(this.assistantColumn).isPresent(), "column %s does not exist", + this.assistantColumn); + + Preconditions.checkState(tableMetadata.getColumn(this.userColumn).isPresent(), "column %s does not exist", + this.userColumn); + } + + private void ensureTableExists() { + if (this.session.getMetadata().getKeyspace(this.schema.keyspace).get().getTable(this.schema.table).isEmpty()) { + CreateTable createTable = null; + + CreateTableStart createTableStart = SchemaBuilder.createTable(this.schema.keyspace, this.schema.table) + .ifNotExists(); + + for (SchemaColumn partitionKey : this.schema.partitionKeys) { + createTable = (null != createTable ? createTable : createTableStart).withPartitionKey(partitionKey.name, + partitionKey.type); + } + for (SchemaColumn clusteringKey : this.schema.clusteringKeys) { + createTable = createTable.withClusteringColumn(clusteringKey.name, clusteringKey.type); + } + + String lastClusteringColumn = this.schema.clusteringKeys.get(this.schema.clusteringKeys.size() - 1).name(); + + CreateTableWithOptions createTableWithOptions = createTable.withColumn(this.userColumn, DataTypes.TEXT) + .withClusteringOrder(lastClusteringColumn, ClusteringOrder.DESC) + // TODO replace w/ SchemaBuilder.unifiedCompactionStrategy() is available + .withOption("compaction", Map.of("class", "UnifiedCompactionStrategy")); + + if (null != this.timeToLiveSeconds) { + createTableWithOptions = createTableWithOptions.withDefaultTimeToLiveSeconds(this.timeToLiveSeconds); + } + this.session.execute(createTableWithOptions.build()); + } + } + + private void ensureTableColumnsExist() { + + TableMetadata tableMetadata = this.session.getMetadata() + .getKeyspace(this.schema.keyspace()) + .get() + .getTable(this.schema.table()) + .get(); + + boolean addAssistantColumn = tableMetadata.getColumn(this.assistantColumn).isEmpty(); + boolean addUserColumn = tableMetadata.getColumn(this.userColumn).isEmpty(); + + if (addAssistantColumn || addUserColumn) { + AlterTableAddColumn alterTable = SchemaBuilder.alterTable(this.schema.keyspace(), this.schema.table()); + if (addAssistantColumn) { + alterTable = alterTable.addColumn(this.assistantColumn, DataTypes.TEXT); + } + if (addUserColumn) { + alterTable = alterTable.addColumn(this.userColumn, DataTypes.TEXT); + } + SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build(); + logger.debug("Executing {}", stmt.getQuery()); + this.session.execute(stmt); + } + } + + /** Given a string sessionId, return the value for each primary key column. */ + public interface SessionIdToPrimaryKeysTranslator extends Function> { + + } + + record Schema(String keyspace, String table, List partitionKeys, List clusteringKeys) { + + } + + public record SchemaColumn(String name, DataType type) { + + public GenericType javaType() { + return CodecRegistry.DEFAULT.codecFor(this.type).getJavaType(); + } + + } + public static class Builder { private CqlSession session = null; @@ -226,14 +318,14 @@ public final class CassandraChatMemoryConfig { public CassandraChatMemoryConfig build() { - int primaryKeyColumns = partitionKeys.size() + clusteringKeys.size(); + int primaryKeyColumns = this.partitionKeys.size() + this.clusteringKeys.size(); int primaryKeysToBind = this.primaryKeyTranslator.apply(UUID.randomUUID().toString()).size(); Preconditions.checkArgument(primaryKeyColumns == primaryKeysToBind + 1, "The primaryKeyTranslator must always return one less element than the number of primary keys in total. The last clustering key remains undefined, expecting to be the timestamp for messages within sessionId. The sessionId can map to any primary key column (though it should map to a partition key column)."); Preconditions.checkArgument( - clusteringKeys.get(clusteringKeys.size() - 1).name().equals(DEFAULT_EXCHANGE_ID_NAME), + this.clusteringKeys.get(this.clusteringKeys.size() - 1).name().equals(DEFAULT_EXCHANGE_ID_NAME), "last clustering key must be the exchangeIdColumn"); return new CassandraChatMemoryConfig(this); @@ -241,92 +333,4 @@ public final class CassandraChatMemoryConfig { } - void ensureSchemaExists() { - if (!disallowSchemaChanges) { - SchemaUtil.ensureKeyspaceExists(this.session, this.schema.keyspace); - ensureTableExists(); - ensureTableColumnsExist(); - SchemaUtil.checkSchemaAgreement(this.session); - } - else { - checkSchemaValid(); - } - } - - void checkSchemaValid() { - - Preconditions.checkState(session.getMetadata().getKeyspace(this.schema.keyspace).isPresent(), - "keyspace %s does not exist", this.schema.keyspace); - - Preconditions.checkState( - session.getMetadata().getKeyspace(this.schema.keyspace).get().getTable(this.schema.table).isPresent(), - "table %s does not exist"); - - TableMetadata tableMetadata = session.getMetadata() - .getKeyspace(this.schema.keyspace) - .get() - .getTable(this.schema.table) - .get(); - - Preconditions.checkState(tableMetadata.getColumn(this.assistantColumn).isPresent(), "column %s does not exist", - this.assistantColumn); - - Preconditions.checkState(tableMetadata.getColumn(this.userColumn).isPresent(), "column %s does not exist", - this.userColumn); - } - - private void ensureTableExists() { - if (session.getMetadata().getKeyspace(schema.keyspace).get().getTable(this.schema.table).isEmpty()) { - CreateTable createTable = null; - - CreateTableStart createTableStart = SchemaBuilder.createTable(this.schema.keyspace, this.schema.table) - .ifNotExists(); - - for (SchemaColumn partitionKey : this.schema.partitionKeys) { - createTable = (null != createTable ? createTable : createTableStart).withPartitionKey(partitionKey.name, - partitionKey.type); - } - for (SchemaColumn clusteringKey : this.schema.clusteringKeys) { - createTable = createTable.withClusteringColumn(clusteringKey.name, clusteringKey.type); - } - - String lastClusteringColumn = this.schema.clusteringKeys.get(this.schema.clusteringKeys.size() - 1).name(); - - CreateTableWithOptions createTableWithOptions = createTable.withColumn(this.userColumn, DataTypes.TEXT) - .withClusteringOrder(lastClusteringColumn, ClusteringOrder.DESC) - // TODO replace w/ SchemaBuilder.unifiedCompactionStrategy() is available - .withOption("compaction", Map.of("class", "UnifiedCompactionStrategy")); - - if (null != this.timeToLiveSeconds) { - createTableWithOptions = createTableWithOptions.withDefaultTimeToLiveSeconds(this.timeToLiveSeconds); - } - this.session.execute(createTableWithOptions.build()); - } - } - - private void ensureTableColumnsExist() { - - TableMetadata tableMetadata = this.session.getMetadata() - .getKeyspace(this.schema.keyspace()) - .get() - .getTable(this.schema.table()) - .get(); - - boolean addAssistantColumn = tableMetadata.getColumn(this.assistantColumn).isEmpty(); - boolean addUserColumn = tableMetadata.getColumn(this.userColumn).isEmpty(); - - if (addAssistantColumn || addUserColumn) { - AlterTableAddColumn alterTable = SchemaBuilder.alterTable(this.schema.keyspace(), this.schema.table()); - if (addAssistantColumn) { - alterTable = alterTable.addColumn(this.assistantColumn, DataTypes.TEXT); - } - if (addUserColumn) { - alterTable = alterTable.addColumn(this.userColumn, DataTypes.TEXT); - } - SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build(); - logger.debug("Executing {}", stmt.getQuery()); - this.session.execute(stmt); - } - } - } diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java index aef2a56b2..ddb310409 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.Collection; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + import com.datastax.oss.driver.api.core.metadata.schema.ColumnMetadata; import com.datastax.oss.driver.api.core.type.DataTypes; import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry; @@ -26,12 +33,6 @@ import org.springframework.ai.vectorstore.filter.Filter.Key; import org.springframework.ai.vectorstore.filter.Filter.Value; import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; -import java.util.Collection; -import java.util.Map; -import java.util.Optional; -import java.util.function.Function; -import java.util.stream.Collectors; - /** * Converts {@link org.springframework.ai.vectorstore.filter.Filter.Expression} into CQL * where clauses. @@ -49,6 +50,24 @@ class CassandraFilterExpressionConverter extends AbstractFilterExpressionConvert .collect(Collectors.toMap((c) -> c.getName().asInternal(), Function.identity())); } + private static void doOperand(ExpressionType type, StringBuilder context) { + switch (type) { + case EQ -> context.append(" = "); + case NE -> context.append(" != "); + case GT -> context.append(" > "); + case GTE -> context.append(" >= "); + case IN -> context.append(" IN "); + case LT -> context.append(" < "); + case LTE -> context.append(" <= "); + // TODO SAI supports collections + // reach out to mck@apache.org if you'd like these implemented + // case CONTAINS -> context.append(" CONTAINS "); + // case CONTAINS_KEY -> context.append(" CONTAINS_KEY "); + default -> throw new UnsupportedOperationException( + String.format("Expression type %s not yet implemented. Patches welcome.", type)); + } + } + @Override protected void doKey(Key key, StringBuilder context) { String keyName = key.key(); @@ -68,24 +87,6 @@ class CassandraFilterExpressionConverter extends AbstractFilterExpressionConvert } } - private static void doOperand(ExpressionType type, StringBuilder context) { - switch (type) { - case EQ -> context.append(" = "); - case NE -> context.append(" != "); - case GT -> context.append(" > "); - case GTE -> context.append(" >= "); - case IN -> context.append(" IN "); - case LT -> context.append(" < "); - case LTE -> context.append(" <= "); - // TODO SAI supports collections - // reach out to mck@apache.org if you'd like these implemented - // case CONTAINS -> context.append(" CONTAINS "); - // case CONTAINS_KEY -> context.append(" CONTAINS_KEY "); - default -> throw new UnsupportedOperationException( - String.format("Expression type %s not yet implemented. Patches welcome.", type)); - } - } - private void doBinaryOperation(String operator, Filter.Expression expression, StringBuilder context) { this.convertOperand(expression.left(), context); context.append(operator); diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java index 35bb49420..349463646 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,17 @@ package org.springframework.ai.vectorstore; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + import com.datastax.oss.driver.api.core.cql.BoundStatement; import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder; import com.datastax.oss.driver.api.core.cql.PreparedStatement; @@ -29,9 +40,7 @@ import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection; import com.datastax.oss.driver.api.querybuilder.insert.InsertInto; import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert; import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; - import io.micrometer.observation.ObservationRegistry; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,17 +59,6 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationCont import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext.Builder; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; - /** * The CassandraVectorStore is for managing and querying vector data in an Apache * Cassandra db. It offers functionalities like adding, deleting, and performing @@ -108,16 +106,6 @@ import java.util.concurrent.ConcurrentMap; */ public class CassandraVectorStore extends AbstractObservationVectorStore implements AutoCloseable { - /** - * Indexes are automatically created with COSINE. This can be changed manually via - * cqlsh - */ - public enum Similarity { - - COSINE, DOT_PRODUCT, EUCLIDEAN; - - } - public static final String SIMILARITY_FIELD_NAME = "similarity_score"; public static final String DRIVER_PROFILE_UPDATES = "spring-ai-updates"; @@ -128,6 +116,10 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStore.class); + private static Map SIMILARITY_TYPE_MAPPING = Map.of(Similarity.COSINE, + VectorStoreSimilarityMetric.COSINE, Similarity.EUCLIDEAN, VectorStoreSimilarityMetric.EUCLIDEAN, + Similarity.DOT_PRODUCT, VectorStoreSimilarityMetric.DOT); + private final CassandraVectorStoreConfig conf; private final EmbeddingModel embeddingModel; @@ -177,6 +169,15 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme this.batchingStrategy = batchingStrategy; } + private static Float[] toFloatArray(float[] embedding) { + Float[] embeddingFloat = new Float[embedding.length]; + int i = 0; + for (Float d : embedding) { + embeddingFloat[i++] = d.floatValue(); + } + return embeddingFloat; + } + @Override public void doAdd(List documents) { var futures = new CompletableFuture[documents.size()]; @@ -275,7 +276,7 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme } void checkSchemaValid() { - this.conf.checkSchemaValid(embeddingModel.dimensions()); + this.conf.checkSchemaValid(this.embeddingModel.dimensions()); } private Similarity getIndexSimilarity(TableMetadata metadata) { @@ -289,7 +290,7 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme private PreparedStatement prepareDeleteStatement() { Delete stmt = null; - DeleteSelection stmtStart = QueryBuilder.deleteFrom(conf.schema.keyspace(), conf.schema.table()); + DeleteSelection stmtStart = QueryBuilder.deleteFrom(this.conf.schema.keyspace(), this.conf.schema.table()); for (var c : this.conf.schema.partitionKeys()) { stmt = (null != stmt ? stmt : stmtStart).whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name())); @@ -344,7 +345,7 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme String similarityFunction = new StringBuilder("similarity_").append(this.similarity.toString().toLowerCase()) .append('(') - .append(conf.schema.embedding()) + .append(this.conf.schema.embedding()) .append(",?)") .toString(); @@ -377,15 +378,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme return this.conf.primaryKeyTranslator.apply(primaryKeyValues); } - private static Float[] toFloatArray(float[] embedding) { - Float[] embeddingFloat = new Float[embedding.length]; - int i = 0; - for (Float d : embedding) { - embeddingFloat[i++] = d.floatValue(); - } - return embeddingFloat; - } - @Override public Builder createObservationContextBuilder(String operationName) { return VectorStoreObservationContext.builder(VectorStoreProvider.CASSANDRA.value(), operationName) @@ -395,10 +387,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme .withSimilarityMetric(getSimilarityMetric()); } - private static Map SIMILARITY_TYPE_MAPPING = Map.of(Similarity.COSINE, - VectorStoreSimilarityMetric.COSINE, Similarity.EUCLIDEAN, VectorStoreSimilarityMetric.EUCLIDEAN, - Similarity.DOT_PRODUCT, VectorStoreSimilarityMetric.DOT); - private String getSimilarityMetric() { if (!SIMILARITY_TYPE_MAPPING.containsKey(this.similarity)) { return this.similarity.name(); @@ -406,4 +394,14 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme return SIMILARITY_TYPE_MAPPING.get(this.similarity).value(); } + /** + * Indexes are automatically created with COSINE. This can be changed manually via + * cqlsh + */ + public enum Similarity { + + COSINE, DOT_PRODUCT, EUCLIDEAN; + + } + } diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java index e50e8008c..65bcba011 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.function.Function; +import java.util.stream.Stream; + import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.CqlSessionBuilder; import com.datastax.oss.driver.api.core.cql.SimpleStatement; @@ -32,23 +44,11 @@ import com.datastax.oss.driver.api.querybuilder.schema.CreateTable; import com.datastax.oss.driver.api.querybuilder.schema.CreateTableStart; import com.datastax.oss.driver.shaded.guava.common.annotations.VisibleForTesting; import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.lang.Nullable; import org.springframework.ai.cassandra.SchemaUtil; - -import java.net.InetSocketAddress; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; -import java.util.function.Function; -import java.util.stream.Stream; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; +import org.springframework.lang.Nullable; /** * Configuration for the Cassandra vector store. @@ -84,51 +84,6 @@ public class CassandraVectorStoreConfig implements AutoCloseable { private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStoreConfig.class); - record Schema(String keyspace, String table, List partitionKeys, List clusteringKeys, - String content, String embedding, String index, Set metadataColumns) { - - } - - public record SchemaColumn(String name, DataType type, SchemaColumnTags... tags) { - public SchemaColumn(String name, DataType type) { - this(name, type, new SchemaColumnTags[0]); - } - - public GenericType javaType() { - return CodecRegistry.DEFAULT.codecFor(type).getJavaType(); - } - - public boolean indexed() { - for (SchemaColumnTags t : tags) { - if (SchemaColumnTags.INDEXED == t) { - return true; - } - } - return false; - } - } - - public enum SchemaColumnTags { - - INDEXED - - } - - /** - * Given a string document id, return the value for each primary key column. - * - * It is a requirement that an empty {@code List} returns an example formatted - * id - */ - public interface DocumentIdTranslator extends Function> { - - } - - /** Given a list of primary key column values, return the document id. */ - public interface PrimaryKeyTranslator extends Function, String> { - - } - final CqlSession session; final Schema schema; @@ -181,6 +136,232 @@ public class CassandraVectorStoreConfig implements AutoCloseable { this.session.execute(SchemaBuilder.dropKeyspace(this.schema.keyspace).ifExists().build()); } + void ensureSchemaExists(int vectorDimension) { + if (!this.disallowSchemaChanges) { + SchemaUtil.ensureKeyspaceExists(this.session, this.schema.keyspace); + ensureTableExists(vectorDimension); + ensureTableColumnsExist(vectorDimension); + ensureIndexesExists(); + SchemaUtil.checkSchemaAgreement(this.session); + } + else { + checkSchemaValid(vectorDimension); + } + } + + void checkSchemaValid(int vectorDimension) { + + Preconditions.checkState(this.session.getMetadata().getKeyspace(this.schema.keyspace).isPresent(), + "keyspace %s does not exist", this.schema.keyspace); + + Preconditions.checkState(this.session.getMetadata() + .getKeyspace(this.schema.keyspace) + .get() + .getTable(this.schema.table) + .isPresent(), "table %s does not exist"); + + TableMetadata tableMetadata = this.session.getMetadata() + .getKeyspace(this.schema.keyspace) + .get() + .getTable(this.schema.table) + .get(); + + Preconditions.checkState(tableMetadata.getColumn(this.schema.content).isPresent(), "column %s does not exist", + this.schema.content); + + Preconditions.checkState(tableMetadata.getColumn(this.schema.embedding).isPresent(), "column %s does not exist", + this.schema.embedding); + + for (SchemaColumn m : this.schema.metadataColumns) { + Optional column = tableMetadata.getColumn(m.name()); + Preconditions.checkState(column.isPresent(), "column %s does not exist", m.name()); + + Preconditions.checkArgument(column.get().getType().equals(m.type()), + "Mismatching type on metadata column %s of %s vs %s", m.name(), column.get().getType(), m.type()); + + if (m.indexed()) { + Preconditions.checkState( + tableMetadata.getIndexes().values().stream().anyMatch((i) -> i.getTarget().equals(m.name())), + "index %s does not exist", m.name()); + } + } + + } + + private void ensureIndexesExists() { + { + SimpleStatement indexStmt = SchemaBuilder.createIndex(this.schema.index) + .ifNotExists() + .custom("StorageAttachedIndex") + .onTable(this.schema.keyspace, this.schema.table) + .andColumn(this.schema.embedding) + .build(); + + logger.debug("Executing {}", indexStmt.getQuery()); + this.session.execute(indexStmt); + } + Stream + .concat(this.schema.partitionKeys.stream(), + Stream.concat(this.schema.clusteringKeys.stream(), this.schema.metadataColumns.stream())) + .filter((cs) -> cs.indexed()) + .forEach((metadata) -> { + + SimpleStatement indexStmt = SchemaBuilder.createIndex(String.format("%s_idx", metadata.name())) + .ifNotExists() + .custom("StorageAttachedIndex") + .onTable(this.schema.keyspace, this.schema.table) + .andColumn(metadata.name()) + .build(); + + logger.debug("Executing {}", indexStmt.getQuery()); + this.session.execute(indexStmt); + }); + } + + private void ensureTableExists(int vectorDimension) { + if (this.session.getMetadata().getKeyspace(this.schema.keyspace).get().getTable(this.schema.table).isEmpty()) { + + CreateTable createTable = null; + + CreateTableStart createTableStart = SchemaBuilder.createTable(this.schema.keyspace, this.schema.table) + .ifNotExists(); + + for (SchemaColumn partitionKey : this.schema.partitionKeys) { + createTable = (null != createTable ? createTable : createTableStart).withPartitionKey(partitionKey.name, + partitionKey.type); + } + for (SchemaColumn clusteringKey : this.schema.clusteringKeys) { + createTable = createTable.withClusteringColumn(clusteringKey.name, clusteringKey.type); + } + + createTable = createTable.withColumn(this.schema.content, DataTypes.TEXT); + + for (SchemaColumn metadata : this.schema.metadataColumns) { + createTable = createTable.withColumn(metadata.name(), metadata.type()); + } + + // https://datastax-oss.atlassian.net/browse/JAVA-3118 + // .withColumn(config.embedding, new DefaultVectorType(DataTypes.FLOAT, + // vectorDimension)); + + StringBuilder tableStmt = new StringBuilder(createTable.asCql()); + tableStmt.setLength(tableStmt.length() - 1); + tableStmt.append(',') + .append(this.schema.embedding) + .append(" vector)"); + logger.debug("Executing {}", tableStmt.toString()); + this.session.execute(tableStmt.toString()); + } + } + + private void ensureTableColumnsExist(int vectorDimension) { + + TableMetadata tableMetadata = this.session.getMetadata() + .getKeyspace(this.schema.keyspace) + .get() + .getTable(this.schema.table) + .get(); + + Set newColumns = new HashSet<>(); + boolean addContent = tableMetadata.getColumn(this.schema.content).isEmpty(); + boolean addEmbedding = tableMetadata.getColumn(this.schema.embedding).isEmpty(); + + for (SchemaColumn metadata : this.schema.metadataColumns) { + Optional column = tableMetadata.getColumn(metadata.name()); + if (column.isPresent()) { + + Preconditions.checkArgument(column.get().getType().equals(metadata.type()), + "Cannot change type on metadata column %s from %s to %s", metadata.name(), + column.get().getType(), metadata.type()); + } + else { + newColumns.add(metadata); + } + } + + if (!newColumns.isEmpty() || addContent || addEmbedding) { + AlterTableAddColumn alterTable = SchemaBuilder.alterTable(this.schema.keyspace, this.schema.table); + for (SchemaColumn metadata : newColumns) { + alterTable = alterTable.addColumn(metadata.name(), metadata.type()); + } + if (addContent) { + alterTable = alterTable.addColumn(this.schema.content, DataTypes.TEXT); + } + if (addEmbedding) { + // special case for embedding column, bc JAVA-3118, as above + StringBuilder alterTableStmt = new StringBuilder(((BuildableQuery) alterTable).asCql()); + if (newColumns.isEmpty() && !addContent) { + alterTableStmt.append(" ADD ("); + } + else { + alterTableStmt.setLength(alterTableStmt.length() - 1); + alterTableStmt.append(','); + } + alterTableStmt.append(this.schema.embedding) + .append(" vector)"); + + logger.debug("Executing {}", alterTableStmt.toString()); + this.session.execute(alterTableStmt.toString()); + } + else { + SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build(); + logger.debug("Executing {}", stmt.getQuery()); + this.session.execute(stmt); + } + } + } + + public enum SchemaColumnTags { + + INDEXED + + } + + /** + * Given a string document id, return the value for each primary key column. + * + * It is a requirement that an empty {@code List} returns an example formatted + * id + */ + public interface DocumentIdTranslator extends Function> { + + } + + /** Given a list of primary key column values, return the document id. */ + public interface PrimaryKeyTranslator extends Function, String> { + + } + + record Schema(String keyspace, String table, List partitionKeys, List clusteringKeys, + String content, String embedding, String index, Set metadataColumns) { + + } + + public record SchemaColumn(String name, DataType type, SchemaColumnTags... tags) { + + public SchemaColumn(String name, DataType type) { + this(name, type, new SchemaColumnTags[0]); + } + + public GenericType javaType() { + return CodecRegistry.DEFAULT.codecFor(this.type).getJavaType(); + } + + public boolean indexed() { + for (SchemaColumnTags t : this.tags) { + if (SchemaColumnTags.INDEXED == t) { + return true; + } + } + return false; + } + + } + public static class Builder { private CqlSession session = null; @@ -383,183 +564,4 @@ public class CassandraVectorStoreConfig implements AutoCloseable { } - void ensureSchemaExists(int vectorDimension) { - if (!this.disallowSchemaChanges) { - SchemaUtil.ensureKeyspaceExists(this.session, this.schema.keyspace); - ensureTableExists(vectorDimension); - ensureTableColumnsExist(vectorDimension); - ensureIndexesExists(); - SchemaUtil.checkSchemaAgreement(session); - } - else { - checkSchemaValid(vectorDimension); - } - } - - void checkSchemaValid(int vectorDimension) { - - Preconditions.checkState(this.session.getMetadata().getKeyspace(this.schema.keyspace).isPresent(), - "keyspace %s does not exist", this.schema.keyspace); - - Preconditions.checkState(this.session.getMetadata() - .getKeyspace(this.schema.keyspace) - .get() - .getTable(this.schema.table) - .isPresent(), "table %s does not exist"); - - TableMetadata tableMetadata = this.session.getMetadata() - .getKeyspace(this.schema.keyspace) - .get() - .getTable(this.schema.table) - .get(); - - Preconditions.checkState(tableMetadata.getColumn(this.schema.content).isPresent(), "column %s does not exist", - this.schema.content); - - Preconditions.checkState(tableMetadata.getColumn(this.schema.embedding).isPresent(), "column %s does not exist", - this.schema.embedding); - - for (SchemaColumn m : this.schema.metadataColumns) { - Optional column = tableMetadata.getColumn(m.name()); - Preconditions.checkState(column.isPresent(), "column %s does not exist", m.name()); - - Preconditions.checkArgument(column.get().getType().equals(m.type()), - "Mismatching type on metadata column %s of %s vs %s", m.name(), column.get().getType(), m.type()); - - if (m.indexed()) { - Preconditions.checkState( - tableMetadata.getIndexes().values().stream().anyMatch((i) -> i.getTarget().equals(m.name())), - "index %s does not exist", m.name()); - } - } - - } - - private void ensureIndexesExists() { - { - SimpleStatement indexStmt = SchemaBuilder.createIndex(this.schema.index) - .ifNotExists() - .custom("StorageAttachedIndex") - .onTable(this.schema.keyspace, this.schema.table) - .andColumn(this.schema.embedding) - .build(); - - logger.debug("Executing {}", indexStmt.getQuery()); - this.session.execute(indexStmt); - } - Stream - .concat(this.schema.partitionKeys.stream(), - Stream.concat(this.schema.clusteringKeys.stream(), this.schema.metadataColumns.stream())) - .filter((cs) -> cs.indexed()) - .forEach((metadata) -> { - - SimpleStatement indexStmt = SchemaBuilder.createIndex(String.format("%s_idx", metadata.name())) - .ifNotExists() - .custom("StorageAttachedIndex") - .onTable(this.schema.keyspace, this.schema.table) - .andColumn(metadata.name()) - .build(); - - logger.debug("Executing {}", indexStmt.getQuery()); - this.session.execute(indexStmt); - }); - } - - private void ensureTableExists(int vectorDimension) { - if (this.session.getMetadata().getKeyspace(this.schema.keyspace).get().getTable(this.schema.table).isEmpty()) { - - CreateTable createTable = null; - - CreateTableStart createTableStart = SchemaBuilder.createTable(this.schema.keyspace, this.schema.table) - .ifNotExists(); - - for (SchemaColumn partitionKey : this.schema.partitionKeys) { - createTable = (null != createTable ? createTable : createTableStart).withPartitionKey(partitionKey.name, - partitionKey.type); - } - for (SchemaColumn clusteringKey : this.schema.clusteringKeys) { - createTable = createTable.withClusteringColumn(clusteringKey.name, clusteringKey.type); - } - - createTable = createTable.withColumn(this.schema.content, DataTypes.TEXT); - - for (SchemaColumn metadata : this.schema.metadataColumns) { - createTable = createTable.withColumn(metadata.name(), metadata.type()); - } - - // https://datastax-oss.atlassian.net/browse/JAVA-3118 - // .withColumn(config.embedding, new DefaultVectorType(DataTypes.FLOAT, - // vectorDimension)); - - StringBuilder tableStmt = new StringBuilder(createTable.asCql()); - tableStmt.setLength(tableStmt.length() - 1); - tableStmt.append(',') - .append(this.schema.embedding) - .append(" vector)"); - logger.debug("Executing {}", tableStmt.toString()); - this.session.execute(tableStmt.toString()); - } - } - - private void ensureTableColumnsExist(int vectorDimension) { - - TableMetadata tableMetadata = this.session.getMetadata() - .getKeyspace(this.schema.keyspace) - .get() - .getTable(this.schema.table) - .get(); - - Set newColumns = new HashSet<>(); - boolean addContent = tableMetadata.getColumn(this.schema.content).isEmpty(); - boolean addEmbedding = tableMetadata.getColumn(this.schema.embedding).isEmpty(); - - for (SchemaColumn metadata : this.schema.metadataColumns) { - Optional column = tableMetadata.getColumn(metadata.name()); - if (column.isPresent()) { - - Preconditions.checkArgument(column.get().getType().equals(metadata.type()), - "Cannot change type on metadata column %s from %s to %s", metadata.name(), - column.get().getType(), metadata.type()); - } - else { - newColumns.add(metadata); - } - } - - if (!newColumns.isEmpty() || addContent || addEmbedding) { - AlterTableAddColumn alterTable = SchemaBuilder.alterTable(this.schema.keyspace, this.schema.table); - for (SchemaColumn metadata : newColumns) { - alterTable = alterTable.addColumn(metadata.name(), metadata.type()); - } - if (addContent) { - alterTable = alterTable.addColumn(this.schema.content, DataTypes.TEXT); - } - if (addEmbedding) { - // special case for embedding column, bc JAVA-3118, as above - StringBuilder alterTableStmt = new StringBuilder(((BuildableQuery) alterTable).asCql()); - if (newColumns.isEmpty() && !addContent) { - alterTableStmt.append(" ADD ("); - } - else { - alterTableStmt.setLength(alterTableStmt.length() - 1); - alterTableStmt.append(','); - } - alterTableStmt.append(this.schema.embedding) - .append(" vector)"); - - logger.debug("Executing {}", alterTableStmt.toString()); - this.session.execute(alterTableStmt.toString()); - } - else { - SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build(); - logger.debug("Executing {}", stmt.getQuery()); - this.session.execute(stmt); - } - } - } - } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java index 8750aa3fe..cc70cd97a 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/chat/memory/CassandraChatMemoryIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/chat/memory/CassandraChatMemoryIT.java index 9e6426113..802b046b5 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/chat/memory/CassandraChatMemoryIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/chat/memory/CassandraChatMemoryIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.memory; import java.time.Duration; @@ -21,11 +22,11 @@ import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.CqlSessionBuilder; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import org.springframework.ai.CassandraImage; import org.testcontainers.containers.CassandraContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; +import org.springframework.ai.CassandraImage; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverterTests.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverterTests.java index 9cf501c20..21d07e4be 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.Collection; @@ -20,8 +21,8 @@ import java.util.HashSet; import java.util.List; import java.util.Set; -import com.datastax.oss.driver.api.core.metadata.schema.ColumnMetadata; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.metadata.schema.ColumnMetadata; import com.datastax.oss.driver.api.core.type.DataTypes; import com.datastax.oss.driver.internal.core.metadata.schema.DefaultColumnMetadata; import org.junit.jupiter.api.Assertions; @@ -47,6 +48,16 @@ import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR */ class CassandraFilterExpressionConverterTests { + private static final CqlIdentifier T = CqlIdentifier.fromInternal("test"); + + private static final Collection COLUMNS = Set.of( + new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("id"), DataTypes.TEXT, false), + new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("content"), DataTypes.TEXT, false), + new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("country"), DataTypes.TEXT, false), + new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("genre"), DataTypes.TEXT, false), + new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("drama"), DataTypes.TEXT, false), + new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("year"), DataTypes.SMALLINT, false)); + @Test void testEQOnPartition() { @@ -199,14 +210,4 @@ class CassandraFilterExpressionConverterTests { assertThat(vectorExpr).isEqualTo("\"'country 1 2 3'\" = 'BG'"); } - private static final CqlIdentifier T = CqlIdentifier.fromInternal("test"); - - private static final Collection COLUMNS = Set.of( - new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("id"), DataTypes.TEXT, false), - new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("content"), DataTypes.TEXT, false), - new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("country"), DataTypes.TEXT, false), - new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("genre"), DataTypes.TEXT, false), - new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("drama"), DataTypes.TEXT, false), - new DefaultColumnMetadata(T, T, CqlIdentifier.fromInternal("year"), DataTypes.SMALLINT, false)); - } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java index 177c50e53..b7684bcc6 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.io.IOException; @@ -36,12 +37,12 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.CassandraImage; import org.testcontainers.containers.CassandraContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.shaded.org.apache.commons.lang3.RandomStringUtils; +import org.springframework.ai.CassandraImage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -91,6 +92,64 @@ class CassandraRichSchemaVectorStoreIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withUserConfiguration(TestApplication.class); + static CassandraVectorStoreConfig.Builder storeBuilder(ApplicationContext context, + List columnOverrides) throws IOException { + + Optional wikiOverride = columnOverrides.stream() + .filter((f) -> "wiki".equals(f.name())) + .findFirst(); + + Optional langOverride = columnOverrides.stream() + .filter((f) -> "language".equals(f.name())) + .findFirst(); + + Optional titleOverride = columnOverrides.stream() + .filter((f) -> "title".equals(f.name())) + .findFirst(); + + Optional chunkNoOverride = columnOverrides.stream() + .filter((f) -> "chunk_no".equals(f.name())) + .findFirst(); + + SchemaColumn wikiSC = wikiOverride.orElse(new SchemaColumn("wiki", DataTypes.TEXT)); + SchemaColumn langSC = langOverride.orElse(new SchemaColumn("language", DataTypes.TEXT)); + SchemaColumn titleSC = titleOverride.orElse(new SchemaColumn("title", DataTypes.TEXT)); + SchemaColumn chunkNoSC = chunkNoOverride.orElse(new SchemaColumn("chunk_no", DataTypes.INT)); + + List partitionKeys = List.of(wikiSC, langSC, titleSC); + List clusteringKeys = List.of(chunkNoSC); + + CassandraVectorStoreConfig.Builder builder = CassandraVectorStoreConfig.builder() + .withCqlSession(context.getBean(CqlSession.class)) + .withKeyspaceName("test_wikidata") + .withTableName("articles") + .withPartitionKeys(partitionKeys) + .withClusteringKeys(clusteringKeys) + .withContentColumnName("body") + .withEmbeddingColumnName("all_minilm_l6_v2_embedding") + .withIndexName("all_minilm_l6_v2_ann") + + .addMetadataColumns(new SchemaColumn("revision", DataTypes.INT), + new SchemaColumn("id", DataTypes.INT, CassandraVectorStoreConfig.SchemaColumnTags.INDEXED)) + + // this store uses '§¶' as a deliminator in the document id between db columns + // 'title' and 'chunk_no' + .withPrimaryKeyTranslator((List primaryKeys) -> { + if (primaryKeys.isEmpty()) { + return "test§¶0"; + } + return format("%s§¶%s", primaryKeys.get(2), primaryKeys.get(3)); + }) + .withDocumentIdTranslator((id) -> { + String[] parts = id.split("§¶"); + String title = parts[0]; + int chunk_no = 0 < parts.length ? Integer.parseInt(parts[1]) : 0; + return List.of("simplewiki", "en", title, chunk_no); + }); + + return builder; + } + @Test void ensureSchemaCreation() { this.contextRunner.run(context -> { @@ -157,7 +216,7 @@ class CassandraRichSchemaVectorStoreIT { @Test void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createStore(context, false).store()) { store.add(documents); @@ -192,7 +251,7 @@ class CassandraRichSchemaVectorStoreIT { int docsPerAdd = 12; // 128; int rounds = 3; - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = new CassandraVectorStore( storeBuilder(context, List.of()).withFixedThreadPoolExecutorSize(nThreads).build(), @@ -231,7 +290,7 @@ class CassandraRichSchemaVectorStoreIT { @Test void searchWithPartitionFilter() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createStore(context, false).store()) { store.add(documents); @@ -282,7 +341,7 @@ class CassandraRichSchemaVectorStoreIT { @Test void unsearchableFilters() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createStore(context, false).store()) { store.add(documents); @@ -301,7 +360,7 @@ class CassandraRichSchemaVectorStoreIT { @Test void searchWithFilters() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createStore(context, false).store()) { store.add(documents); @@ -366,7 +425,7 @@ class CassandraRichSchemaVectorStoreIT { @Test void searchWithFilterOnPrimaryKeys() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { List overrides = List.of( new SchemaColumn("title", DataTypes.TEXT, CassandraVectorStoreConfig.SchemaColumnTags.INDEXED), @@ -402,7 +461,7 @@ class CassandraRichSchemaVectorStoreIT { @Test void documentUpdate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createStore(context, false).store()) { store.add(documents); @@ -453,7 +512,7 @@ class CassandraRichSchemaVectorStoreIT { @Test void searchWithThreshold() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createStore(context, false).store()) { store.add(documents); @@ -483,27 +542,6 @@ class CassandraRichSchemaVectorStoreIT { }); } - @SpringBootConfiguration - @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) - public static class TestApplication { - - @Bean - public EmbeddingModel embeddingModel() { - // default is ONNX all-MiniLM-L6-v2 - return new TransformersEmbeddingModel(); - } - - @Bean - public CqlSession cqlSession() { - return new CqlSessionBuilder() - // comment next two lines out to connect to a local C* cluster - .addContactPoint(cassandraContainer.getContactPoint()) - .withLocalDatacenter(cassandraContainer.getLocalDatacenter()) - .build(); - } - - } - private StoreWrapper createStore(ApplicationContext context, boolean disallowSchemaCreation) throws IOException { @@ -526,64 +564,6 @@ class CassandraRichSchemaVectorStoreIT { return new StoreWrapper(new CassandraVectorStore(conf, context.getBean(EmbeddingModel.class)), conf); } - static CassandraVectorStoreConfig.Builder storeBuilder(ApplicationContext context, - List columnOverrides) throws IOException { - - Optional wikiOverride = columnOverrides.stream() - .filter((f) -> "wiki".equals(f.name())) - .findFirst(); - - Optional langOverride = columnOverrides.stream() - .filter((f) -> "language".equals(f.name())) - .findFirst(); - - Optional titleOverride = columnOverrides.stream() - .filter((f) -> "title".equals(f.name())) - .findFirst(); - - Optional chunkNoOverride = columnOverrides.stream() - .filter((f) -> "chunk_no".equals(f.name())) - .findFirst(); - - SchemaColumn wikiSC = wikiOverride.orElse(new SchemaColumn("wiki", DataTypes.TEXT)); - SchemaColumn langSC = langOverride.orElse(new SchemaColumn("language", DataTypes.TEXT)); - SchemaColumn titleSC = titleOverride.orElse(new SchemaColumn("title", DataTypes.TEXT)); - SchemaColumn chunkNoSC = chunkNoOverride.orElse(new SchemaColumn("chunk_no", DataTypes.INT)); - - List partitionKeys = List.of(wikiSC, langSC, titleSC); - List clusteringKeys = List.of(chunkNoSC); - - CassandraVectorStoreConfig.Builder builder = CassandraVectorStoreConfig.builder() - .withCqlSession(context.getBean(CqlSession.class)) - .withKeyspaceName("test_wikidata") - .withTableName("articles") - .withPartitionKeys(partitionKeys) - .withClusteringKeys(clusteringKeys) - .withContentColumnName("body") - .withEmbeddingColumnName("all_minilm_l6_v2_embedding") - .withIndexName("all_minilm_l6_v2_ann") - - .addMetadataColumns(new SchemaColumn("revision", DataTypes.INT), - new SchemaColumn("id", DataTypes.INT, CassandraVectorStoreConfig.SchemaColumnTags.INDEXED)) - - // this store uses '§¶' as a deliminator in the document id between db columns - // 'title' and 'chunk_no' - .withPrimaryKeyTranslator((List primaryKeys) -> { - if (primaryKeys.isEmpty()) { - return "test§¶0"; - } - return format("%s§¶%s", primaryKeys.get(2), primaryKeys.get(3)); - }) - .withDocumentIdTranslator((id) -> { - String[] parts = id.split("§¶"); - String title = parts[0]; - int chunk_no = 0 < parts.length ? Integer.parseInt(parts[1]) : 0; - return List.of("simplewiki", "en", title, chunk_no); - }); - - return builder; - } - private void executeCqlFile(ApplicationContext context, String filename) throws IOException { logger.info("executing {}", filename); @@ -599,7 +579,29 @@ class CassandraRichSchemaVectorStoreIT { } } + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + public static class TestApplication { + + @Bean + public EmbeddingModel embeddingModel() { + // default is ONNX all-MiniLM-L6-v2 + return new TransformersEmbeddingModel(); + } + + @Bean + public CqlSession cqlSession() { + return new CqlSessionBuilder() + // comment next two lines out to connect to a local C* cluster + .addContactPoint(cassandraContainer.getContactPoint()) + .withLocalDatacenter(cassandraContainer.getLocalDatacenter()) + .build(); + } + + } + public record StoreWrapper(K store, V conf) { + } } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java index 2c7571516..03dd67c27 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.io.IOException; @@ -29,11 +30,11 @@ import com.datastax.oss.driver.api.core.servererrors.SyntaxError; import com.datastax.oss.driver.api.core.type.DataTypes; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import org.springframework.ai.CassandraImage; import org.testcontainers.containers.CassandraContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; +import org.springframework.ai.CassandraImage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -84,6 +85,26 @@ class CassandraVectorStoreIT { } } + private static CassandraVectorStoreConfig.Builder storeBuilder(CqlSession cqlSession) { + return CassandraVectorStoreConfig.builder() + .withCqlSession(cqlSession) + .withKeyspaceName("test_" + CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME); + } + + private static CassandraVectorStore createTestStore(ApplicationContext context, SchemaColumn... metadataFields) { + CassandraVectorStoreConfig.Builder builder = storeBuilder(context.getBean(CqlSession.class)) + .addMetadataColumns(metadataFields); + + return createTestStore(context, builder); + } + + private static CassandraVectorStore createTestStore(ApplicationContext context, + CassandraVectorStoreConfig.Builder builder) { + CassandraVectorStoreConfig conf = builder.build(); + conf.dropKeyspace(); + return new CassandraVectorStore(conf, context.getBean(EmbeddingModel.class)); + } + @Test void ensureBeanGetsCreated() { this.contextRunner.run(context -> { @@ -96,7 +117,7 @@ class CassandraVectorStoreIT { @Test void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createTestStore(context, new SchemaColumn("meta1", DataTypes.TEXT), new SchemaColumn("meta2", DataTypes.TEXT))) { @@ -132,7 +153,7 @@ class CassandraVectorStoreIT { @Test void addAndSearchReturnEmbeddings() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { CassandraVectorStoreConfig.Builder builder = storeBuilder(context.getBean(CqlSession.class)) .returnEmbeddings(); @@ -168,7 +189,7 @@ class CassandraVectorStoreIT { @Test void searchWithPartitionFilter() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createTestStore(context, new SchemaColumn("year", DataTypes.SMALLINT, SchemaColumnTags.INDEXED))) { @@ -224,7 +245,7 @@ class CassandraVectorStoreIT { @Test void unsearchableFilters() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = context.getBean(CassandraVectorStore.class)) { var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", @@ -251,7 +272,7 @@ class CassandraVectorStoreIT { @Test void searchWithFilters() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = createTestStore(context, new SchemaColumn("country", DataTypes.TEXT, SchemaColumnTags.INDEXED), @@ -314,7 +335,7 @@ class CassandraVectorStoreIT { @Test void documentUpdate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = context.getBean(CassandraVectorStore.class)) { Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", @@ -351,7 +372,7 @@ class CassandraVectorStoreIT { @Test void searchWithThreshold() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { try (CassandraVectorStore store = context.getBean(CassandraVectorStore.class)) { store.add(documents()); @@ -414,24 +435,4 @@ class CassandraVectorStoreIT { } - private static CassandraVectorStoreConfig.Builder storeBuilder(CqlSession cqlSession) { - return CassandraVectorStoreConfig.builder() - .withCqlSession(cqlSession) - .withKeyspaceName("test_" + CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME); - } - - private static CassandraVectorStore createTestStore(ApplicationContext context, SchemaColumn... metadataFields) { - CassandraVectorStoreConfig.Builder builder = storeBuilder(context.getBean(CqlSession.class)) - .addMetadataColumns(metadataFields); - - return createTestStore(context, builder); - } - - private static CassandraVectorStore createTestStore(ApplicationContext context, - CassandraVectorStoreConfig.Builder builder) { - CassandraVectorStoreConfig conf = builder.build(); - conf.dropKeyspace(); - return new CassandraVectorStore(conf, context.getBean(EmbeddingModel.class)); - } - } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreObservationIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreObservationIT.java index a8cc74dfb..e92349efa 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.CqlSessionBuilder; +import com.datastax.oss.driver.api.core.type.DataTypes; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.CassandraContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.CassandraImage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -40,17 +49,8 @@ import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.containers.CassandraContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import com.datastax.oss.driver.api.core.CqlSession; -import com.datastax.oss.driver.api.core.CqlSessionBuilder; -import com.datastax.oss.driver.api.core.type.DataTypes; - -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -80,16 +80,22 @@ public class CassandraVectorStoreObservationIT { } } + private static CassandraVectorStoreConfig.Builder storeBuilder(CqlSession cqlSession) { + return CassandraVectorStoreConfig.builder() + .withCqlSession(cqlSession) + .withKeyspaceName("test_" + CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME); + } + @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() @@ -193,10 +199,4 @@ public class CassandraVectorStoreObservationIT { } - private static CassandraVectorStoreConfig.Builder storeBuilder(CqlSession cqlSession) { - return CassandraVectorStoreConfig.builder() - .withCqlSession(cqlSession) - .withKeyspaceName("test_" + CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME); - } - } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java index 301b61c49..7189351da 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; diff --git a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_full_schema.cql b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_full_schema.cql index c6f6cf17a..86a8d93fb 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_full_schema.cql +++ b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_full_schema.cql @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + CREATE KEYSPACE IF NOT EXISTS test_wikidata WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}; CREATE TABLE IF NOT EXISTS test_wikidata.articles ( diff --git a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_0_schema.cql b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_0_schema.cql index d2f3fcd62..42724e314 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_0_schema.cql +++ b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_0_schema.cql @@ -1 +1,17 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + CREATE KEYSPACE IF NOT EXISTS test_wikidata WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}; diff --git a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_1_schema.cql b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_1_schema.cql index cb1a53582..5b0064c30 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_1_schema.cql +++ b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_1_schema.cql @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + CREATE KEYSPACE IF NOT EXISTS test_wikidata WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}; CREATE TABLE IF NOT EXISTS test_wikidata.articles ( diff --git a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_2_schema.cql b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_2_schema.cql index 5853b2274..759374499 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_2_schema.cql +++ b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_2_schema.cql @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + CREATE KEYSPACE IF NOT EXISTS test_wikidata WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}; CREATE TABLE IF NOT EXISTS test_wikidata.articles ( diff --git a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_3_schema.cql b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_3_schema.cql index a605116ca..673a77e68 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_3_schema.cql +++ b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_3_schema.cql @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + CREATE KEYSPACE IF NOT EXISTS test_wikidata WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}; CREATE TABLE IF NOT EXISTS test_wikidata.articles ( diff --git a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_4_schema.cql b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_4_schema.cql index 68b4583c4..564eb2333 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_4_schema.cql +++ b/vector-stores/spring-ai-cassandra-store/src/test/resources/test_wiki_partial_4_schema.cql @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + CREATE KEYSPACE IF NOT EXISTS test_wikidata WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}; CREATE TABLE IF NOT EXISTS test_wikidata.articles ( diff --git a/vector-stores/spring-ai-chroma-store/pom.xml b/vector-stores/spring-ai-chroma-store/pom.xml index c057ea339..51b1fa4d8 100644 --- a/vector-stores/spring-ai-chroma-store/pom.xml +++ b/vector-stores/spring-ai-chroma-store/pom.xml @@ -1,4 +1,20 @@ + + diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java index ced7e44ff..8d4e5b036 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chroma; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Consumer; import java.util.regex.Matcher; import java.util.regex.Pattern; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + import org.springframework.ai.chroma.ChromaApi.QueryRequest.Include; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; @@ -36,10 +40,6 @@ import org.springframework.web.client.HttpServerErrorException; import org.springframework.web.client.HttpStatusCodeException; import org.springframework.web.client.RestClient; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; - /** * Single-class Chroma API implementation based on the (unofficial) Chroma REST API. * @@ -54,10 +54,10 @@ public class ChromaApi { // Regular expression pattern that looks for a message. private static Pattern MESSAGE_ERROR_PATTERN = Pattern.compile("\"message\":\"(.*?)\""); - private RestClient restClient; - private final ObjectMapper objectMapper; + private RestClient restClient; + private String keyToken; public ChromaApi(String baseUrl) { @@ -99,164 +99,6 @@ public class ChromaApi { return this; } - /** - * Chroma embedding collection. - * - * @param id Collection Id. - * @param name The name of the collection. - * @param metadata Metadata associated with the collection. - */ - public record Collection(String id, String name, Map metadata) { - } - - /** - * Request to create a new collection with the given name and metadata. - * - * @param name The name of the collection to create. - * @param metadata Optional metadata to associate with the collection. - */ - public record CreateCollectionRequest(String name, Map metadata) { - public CreateCollectionRequest(String name) { - this(name, new HashMap<>(Map.of("hnsw:space", "cosine"))); - } - } - - /** - * Add embeddings to the chroma data store. - * - * @param ids The ids of the embeddings to add. - * @param embeddings The embeddings to add. - * @param metadata The metadata to associate with the embeddings. When querying, you - * can filter on this metadata. - * @param documents The documents contents to associate with the embeddings. - */ - public record AddEmbeddingsRequest(List ids, List embeddings, - @JsonProperty("metadatas") List> metadata, List documents) { - - // Convenance for adding a single embedding. - public AddEmbeddingsRequest(String id, float[] embedding, Map metadata, String document) { - this(List.of(id), List.of(embedding), List.of(metadata), List.of(document)); - } - } - - /** - * Request to delete embedding from a collection. - * - * @param ids The ids of the embeddings to delete. (Optional) - * @param where Condition to filter items to delete based on metadata values. - * (Optional) - */ - public record DeleteEmbeddingsRequest(List ids, Map where) { - public DeleteEmbeddingsRequest(List ids) { - this(ids, Map.of()); - } - } - - /** - * Get embeddings from a collection. - * - * @param ids IDs of the embeddings to get. - * @param where Condition to filter results based on metadata values. - * @param limit Limit on the number of collection embeddings to get. - * @param offset Offset on the embeddings to get. - * @param include A list of what to include in the results. Can contain "embeddings", - * "metadatas", "documents", "distances". Ids are always included. Defaults to - * [metadatas, documents, distances]. - */ - public record GetEmbeddingsRequest(List ids, Map where, int limit, int offset, - List include) { - - public GetEmbeddingsRequest(List ids) { - this(ids, Map.of(), 10, 0, Include.all); - } - - public GetEmbeddingsRequest(List ids, Map where) { - this(ids, where, 10, 0, Include.all); - } - - public GetEmbeddingsRequest(List ids, Map where, int limit, int offset) { - this(ids, where, limit, offset, Include.all); - } - } - - /** - * Object containing the get embedding results. - * - * @param ids List of document ids. One for each returned document. - * @param embeddings List of document embeddings. One for each returned document. - * @param documents List of document contents. One for each returned document. - * @param metadata List of document metadata. One for each returned document. - */ - public record GetEmbeddingResponse(List ids, List embeddings, List documents, - @JsonProperty("metadatas") List> metadata) { - } - - /** - * Request to get the nResults nearest neighbor embeddings for provided - * queryEmbeddings. - * - * @param queryEmbeddings The embeddings to get the closes neighbors of. - * @param nResults The number of neighbors to return for each query_embedding or - * query_texts. - * @param where Condition to filter results based on metadata values. - * @param include A list of what to include in the results. Can contain "embeddings", - * "metadatas", "documents", "distances". Ids are always included. Defaults to - * [metadatas, documents, distances]. - */ - public record QueryRequest(@JsonProperty("query_embeddings") List queryEmbeddings, - @JsonProperty("n_results") int nResults, Map where, List include) { - - public enum Include { - - @JsonProperty("metadatas") - METADATAS, - - @JsonProperty("documents") - DOCUMENTS, - - @JsonProperty("distances") - DISTANCES, - - @JsonProperty("embeddings") - EMBEDDINGS; - - public static final List all = List.of(METADATAS, DOCUMENTS, DISTANCES, EMBEDDINGS); - - } - - /** - * Convenience to query for a single embedding instead of a batch of embeddings. - */ - public QueryRequest(float[] queryEmbedding, int nResults) { - this(List.of(queryEmbedding), nResults, Map.of(), Include.all); - } - - public QueryRequest(float[] queryEmbedding, int nResults, Map where) { - this(List.of(queryEmbedding), nResults, where, Include.all); - } - } - - /** - * A QueryResponse object containing the query results. - * - * @param ids List of list of document ids. One for each returned document. - * @param embeddings List of list of document embeddings. One for each returned - * document. - * @param documents List of list of document contents. One for each returned document. - * @param metadata List of list of document metadata. One for each returned document. - * @param distances List of list of search distances. One for each returned document. - */ - public record QueryResponse(List> ids, List> embeddings, List> documents, - @JsonProperty("metadatas") List>> metadata, List> distances) { - } - - /** - * Single query embedding response. - */ - public record Embedding(String id, float[] embedding, String document, Map metadata, - Double distances) { - } - public List toEmbeddingResponseList(QueryResponse queryResponse) { List result = new ArrayList<>(); @@ -271,10 +113,6 @@ public class ChromaApi { return result; } - // - // Chroma Client API (https://docs.trychroma.com/js_reference/Client) - // - public Collection createCollection(CreateCollectionRequest createCollectionRequest) { return this.restClient.post() @@ -330,10 +168,6 @@ public class ChromaApi { } } - private static class CollectionList extends ArrayList { - - } - public List listCollections() { return this.restClient.get() @@ -344,10 +178,6 @@ public class ChromaApi { .getBody(); } - // - // Chroma Collection API (https://docs.trychroma.com/js_reference/Collection) - // - public void upsertEmbeddings(String collectionId, AddEmbeddingsRequest embedding) { this.restClient.post() @@ -366,6 +196,7 @@ public class ChromaApi { .body(deleteRequest) .retrieve() .toEntity(new ParameterizedTypeReference>() { + }) .getBody(); } @@ -391,6 +222,10 @@ public class ChromaApi { .getBody(); } + // + // Chroma Client API (https://docs.trychroma.com/js_reference/Client) + // + public GetEmbeddingResponse getEmbeddings(String collectionId, GetEmbeddingsRequest getEmbeddingsRequest) { return this.restClient.post() @@ -442,4 +277,181 @@ public class ChromaApi { return ""; } + /** + * Chroma embedding collection. + * + * @param id Collection Id. + * @param name The name of the collection. + * @param metadata Metadata associated with the collection. + */ + public record Collection(String id, String name, Map metadata) { + + } + + /** + * Request to create a new collection with the given name and metadata. + * + * @param name The name of the collection to create. + * @param metadata Optional metadata to associate with the collection. + */ + public record CreateCollectionRequest(String name, Map metadata) { + + public CreateCollectionRequest(String name) { + this(name, new HashMap<>(Map.of("hnsw:space", "cosine"))); + } + + } + + // + // Chroma Collection API (https://docs.trychroma.com/js_reference/Collection) + // + + /** + * Add embeddings to the chroma data store. + * + * @param ids The ids of the embeddings to add. + * @param embeddings The embeddings to add. + * @param metadata The metadata to associate with the embeddings. When querying, you + * can filter on this metadata. + * @param documents The documents contents to associate with the embeddings. + */ + public record AddEmbeddingsRequest(List ids, List embeddings, + @JsonProperty("metadatas") List> metadata, List documents) { + + // Convenance for adding a single embedding. + public AddEmbeddingsRequest(String id, float[] embedding, Map metadata, String document) { + this(List.of(id), List.of(embedding), List.of(metadata), List.of(document)); + } + + } + + /** + * Request to delete embedding from a collection. + * + * @param ids The ids of the embeddings to delete. (Optional) + * @param where Condition to filter items to delete based on metadata values. + * (Optional) + */ + public record DeleteEmbeddingsRequest(List ids, Map where) { + + public DeleteEmbeddingsRequest(List ids) { + this(ids, Map.of()); + } + + } + + /** + * Get embeddings from a collection. + * + * @param ids IDs of the embeddings to get. + * @param where Condition to filter results based on metadata values. + * @param limit Limit on the number of collection embeddings to get. + * @param offset Offset on the embeddings to get. + * @param include A list of what to include in the results. Can contain "embeddings", + * "metadatas", "documents", "distances". Ids are always included. Defaults to + * [metadatas, documents, distances]. + */ + public record GetEmbeddingsRequest(List ids, Map where, int limit, int offset, + List include) { + + public GetEmbeddingsRequest(List ids) { + this(ids, Map.of(), 10, 0, Include.all); + } + + public GetEmbeddingsRequest(List ids, Map where) { + this(ids, where, 10, 0, Include.all); + } + + public GetEmbeddingsRequest(List ids, Map where, int limit, int offset) { + this(ids, where, limit, offset, Include.all); + } + + } + + /** + * Object containing the get embedding results. + * + * @param ids List of document ids. One for each returned document. + * @param embeddings List of document embeddings. One for each returned document. + * @param documents List of document contents. One for each returned document. + * @param metadata List of document metadata. One for each returned document. + */ + public record GetEmbeddingResponse(List ids, List embeddings, List documents, + @JsonProperty("metadatas") List> metadata) { + + } + + /** + * Request to get the nResults nearest neighbor embeddings for provided + * queryEmbeddings. + * + * @param queryEmbeddings The embeddings to get the closes neighbors of. + * @param nResults The number of neighbors to return for each query_embedding or + * query_texts. + * @param where Condition to filter results based on metadata values. + * @param include A list of what to include in the results. Can contain "embeddings", + * "metadatas", "documents", "distances". Ids are always included. Defaults to + * [metadatas, documents, distances]. + */ + public record QueryRequest(@JsonProperty("query_embeddings") List queryEmbeddings, + @JsonProperty("n_results") int nResults, Map where, List include) { + + /** + * Convenience to query for a single embedding instead of a batch of embeddings. + */ + public QueryRequest(float[] queryEmbedding, int nResults) { + this(List.of(queryEmbedding), nResults, Map.of(), Include.all); + } + + public QueryRequest(float[] queryEmbedding, int nResults, Map where) { + this(List.of(queryEmbedding), nResults, where, Include.all); + } + + public enum Include { + + @JsonProperty("metadatas") + METADATAS, + + @JsonProperty("documents") + DOCUMENTS, + + @JsonProperty("distances") + DISTANCES, + + @JsonProperty("embeddings") + EMBEDDINGS; + + public static final List all = List.of(METADATAS, DOCUMENTS, DISTANCES, EMBEDDINGS); + + } + + } + + /** + * A QueryResponse object containing the query results. + * + * @param ids List of list of document ids. One for each returned document. + * @param embeddings List of list of document embeddings. One for each returned + * document. + * @param documents List of list of document contents. One for each returned document. + * @param metadata List of list of document metadata. One for each returned document. + * @param distances List of list of search distances. One for each returned document. + */ + public record QueryResponse(List> ids, List> embeddings, List> documents, + @JsonProperty("metadatas") List>> metadata, List> distances) { + + } + + /** + * Single query embedding response. + */ + public record Embedding(String id, float[] embedding, String document, Map metadata, + Double distances) { + + } + + private static class CollectionList extends ArrayList { + + } + } diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaFilterExpressionConverter.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaFilterExpressionConverter.java index cd60e83dc..526fc30cf 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaFilterExpressionConverter.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.springframework.ai.vectorstore.filter.Filter; diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java index f77786806..7afb13c4b 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java @@ -22,6 +22,11 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.json.JsonMapper; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.chroma.ChromaApi.AddEmbeddingsRequest; import org.springframework.ai.chroma.ChromaApi.DeleteEmbeddingsRequest; @@ -43,11 +48,6 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.json.JsonMapper; -import io.micrometer.observation.ObservationRegistry; - /** * {@link ChromaVectorStore} is a concrete implementation of the {@link VectorStore} * interface. It is responsible for adding, deleting, and searching documents based on @@ -229,4 +229,4 @@ public class ChromaVectorStore extends AbstractObservationVectorStore implements .withFieldName(this.initializeSchema ? DISTANCE_FIELD_NAME : null); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java index 9cbbaa4a4..51208d3d9 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java index c1c934df7..bfc843402 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.chroma; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.chroma; import java.util.List; import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ChromaImage; import org.springframework.ai.chroma.ChromaApi.AddEmbeddingsRequest; import org.springframework.ai.chroma.ChromaApi.Collection; @@ -31,9 +34,8 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -52,57 +54,58 @@ public class ChromaApiIT { @BeforeEach public void beforeEach() { - chroma.listCollections().stream().forEach(c -> chroma.deleteCollection(c.name())); + this.chroma.listCollections().stream().forEach(c -> this.chroma.deleteCollection(c.name())); } @Test public void testClientWithMetadata() { Map metadata = Map.of("hnsw:space", "cosine", "hnsw:M", 5); - var newCollection = chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection", metadata)); + var newCollection = this.chroma + .createCollection(new ChromaApi.CreateCollectionRequest("TestCollection", metadata)); assertThat(newCollection).isNotNull(); assertThat(newCollection.name()).isEqualTo("TestCollection"); } @Test public void testClient() { - var newCollection = chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); + var newCollection = this.chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); assertThat(newCollection).isNotNull(); assertThat(newCollection.name()).isEqualTo("TestCollection"); - var getCollection = chroma.getCollection("TestCollection"); + var getCollection = this.chroma.getCollection("TestCollection"); assertThat(getCollection).isNotNull(); assertThat(getCollection.name()).isEqualTo("TestCollection"); assertThat(getCollection.id()).isEqualTo(newCollection.id()); - List collections = chroma.listCollections(); + List collections = this.chroma.listCollections(); assertThat(collections).hasSize(1); assertThat(collections.get(0).id()).isEqualTo(newCollection.id()); - chroma.deleteCollection(newCollection.name()); - assertThat(chroma.listCollections()).hasSize(0); + this.chroma.deleteCollection(newCollection.name()); + assertThat(this.chroma.listCollections()).hasSize(0); } @Test public void testCollection() { - var newCollection = chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); - assertThat(chroma.countEmbeddings(newCollection.id())).isEqualTo(0); + var newCollection = this.chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); + assertThat(this.chroma.countEmbeddings(newCollection.id())).isEqualTo(0); var addEmbeddingRequest = new AddEmbeddingsRequest(List.of("id1", "id2"), List.of(new float[] { 1f, 1f, 1f }, new float[] { 2f, 2f, 2f }), List.of(Map.of(), Map.of("key1", "value1", "key2", true, "key3", 23.4)), List.of("Hello World", "Big World")); - chroma.upsertEmbeddings(newCollection.id(), addEmbeddingRequest); + this.chroma.upsertEmbeddings(newCollection.id(), addEmbeddingRequest); var addEmbeddingRequest2 = new AddEmbeddingsRequest("id3", new float[] { 3f, 3f, 3f }, Map.of("key1", "value1", "key2", true, "key3", 23.4), "Big World"); - chroma.upsertEmbeddings(newCollection.id(), addEmbeddingRequest2); + this.chroma.upsertEmbeddings(newCollection.id(), addEmbeddingRequest2); - assertThat(chroma.countEmbeddings(newCollection.id())).isEqualTo(3); + assertThat(this.chroma.countEmbeddings(newCollection.id())).isEqualTo(3); - var queryResult = chroma.queryCollection(newCollection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, chroma.where(""" + var queryResult = this.chroma.queryCollection(newCollection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" { "key2" : { "$eq": true } } @@ -111,14 +114,14 @@ public class ChromaApiIT { assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id2", "id3"); // Update existing embedding. - chroma.upsertEmbeddings(newCollection.id(), new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f }, + this.chroma.upsertEmbeddings(newCollection.id(), new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f }, Map.of("key1", "value2", "key2", false, "key4", 23.4), "Small World")); - var result = chroma.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2"))); + var result = this.chroma.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2"))); assertThat(result.ids().get(0)).isEqualTo("id2"); - queryResult = chroma.queryCollection(newCollection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, chroma.where(""" + queryResult = this.chroma.queryCollection(newCollection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" { "key2" : { "$eq": true } } @@ -130,7 +133,7 @@ public class ChromaApiIT { @Test public void testQueryWhere() { - var collection = chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); + var collection = this.chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); var add1 = new AddEmbeddingsRequest("id1", new float[] { 1f, 1f, 1f }, Map.of("country", "BG", "active", true, "price", 23.4, "year", 2020), @@ -143,24 +146,24 @@ public class ChromaApiIT { Map.of("country", "BG", "active", false, "price", 40.1, "year", 2023), "The World is Big and Salvation Lurks Around the Corner"); - chroma.upsertEmbeddings(collection.id(), add1); - chroma.upsertEmbeddings(collection.id(), add2); - chroma.upsertEmbeddings(collection.id(), add3); + this.chroma.upsertEmbeddings(collection.id(), add1); + this.chroma.upsertEmbeddings(collection.id(), add2); + this.chroma.upsertEmbeddings(collection.id(), add3); - assertThat(chroma.countEmbeddings(collection.id())).isEqualTo(3); + assertThat(this.chroma.countEmbeddings(collection.id())).isEqualTo(3); - var queryResult = chroma.queryCollection(collection.id(), new QueryRequest(new float[] { 1f, 1f, 1f }, 3)); + var queryResult = this.chroma.queryCollection(collection.id(), new QueryRequest(new float[] { 1f, 1f, 1f }, 3)); assertThat(queryResult.ids().get(0)).hasSize(3); assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id2", "id3"); - var chromaEmbeddings = chroma.toEmbeddingResponseList(queryResult); + var chromaEmbeddings = this.chroma.toEmbeddingResponseList(queryResult); assertThat(chromaEmbeddings).hasSize(3); assertThat(chromaEmbeddings).hasSize(3); - queryResult = chroma.queryCollection(collection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, chroma.where(""" + queryResult = this.chroma.queryCollection(collection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" { "$and" : [ {"country" : { "$eq": "BG"}}, @@ -171,8 +174,8 @@ public class ChromaApiIT { assertThat(queryResult.ids().get(0)).hasSize(2); assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id3"); - queryResult = chroma.queryCollection(collection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, chroma.where(""" + queryResult = this.chroma.queryCollection(collection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" { "$and" : [ {"country" : { "$eq": "BG"}}, diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java index b0083010e..4c748dabb 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.util.List; import java.util.Map; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.MountableFile; + import org.springframework.ai.ChromaImage; import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.document.Document; @@ -32,10 +36,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.web.client.RestClient; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.MountableFile; + +import static org.assertj.core.api.Assertions.assertThat; /** * ChromaDB with Basic Authentication: @@ -68,7 +70,7 @@ public class BasicAuthChromaWhereIT { @Test public void withInFiltersExpressions1() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java index 82cd3bfbd..ec099430c 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.util.Collections; import java.util.List; @@ -23,6 +22,10 @@ import java.util.Map; import java.util.UUID; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ChromaImage; import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.document.Document; @@ -34,9 +37,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.web.client.RestClient; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -49,6 +51,10 @@ public class ChromaVectorStoreIT { @Container static ChromaDBContainer chromaContainer = new ChromaDBContainer(ChromaImage.DEFAULT_IMAGE); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class) + .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")); + List documents = List.of( new Document("Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!!", Collections.singletonMap("meta1", "meta1")), @@ -57,29 +63,25 @@ public class ChromaVectorStoreIT { "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression", Collections.singletonMap("meta2", "meta2"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(TestApplication.class) - .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")); - @Test public void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List results = vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); List results2 = vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); assertThat(results2).hasSize(0); @@ -89,7 +91,7 @@ public class ChromaVectorStoreIT { @Test public void addAndSearchWithFilters() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -129,7 +131,7 @@ public class ChromaVectorStoreIT { public void documentUpdateTest() { // Note ,using OpenAI to calculate embeddings - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -170,11 +172,11 @@ public class ChromaVectorStoreIT { @Test public void searchThresholdTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); var request = SearchRequest.query("Great").withTopK(5); List fullResult = vectorStore.similaritySearch(request.withSimilarityThresholdAll()); @@ -189,14 +191,14 @@ public class ChromaVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); }); } diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java index 89e9eab2e..3b39051ee 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ChromaImage; import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.document.Document; @@ -42,13 +48,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.web.client.RestClient; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -81,13 +82,13 @@ public class ChromaVectorStoreObservationIT { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { ChromaVectorStore vectorStore = context.getBean(ChromaVectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java index 84d7a0bb4..19ebd4b9d 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.util.List; import java.util.Map; import org.junit.jupiter.api.Test; +import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.ChromaImage; import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.document.Document; @@ -32,9 +35,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.web.client.RestClient; -import org.testcontainers.chromadb.ChromaDBContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; /** * ChromaDB with static API Token Authentication: @@ -69,7 +71,7 @@ public class TokenSecuredChromaWhereIT { @Test public void withInFiltersExpressions1() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -93,7 +95,7 @@ public class TokenSecuredChromaWhereIT { @Test public void withInFiltersExpressions() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); diff --git a/vector-stores/spring-ai-elasticsearch-store/pom.xml b/vector-stores/spring-ai-elasticsearch-store/pom.xml index 6dea5967c..89cb9e6e6 100644 --- a/vector-stores/spring-ai-elasticsearch-store/pom.xml +++ b/vector-stores/spring-ai-elasticsearch-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverter.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverter.java index c8d0701bf..e7b2c5a01 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverter.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import org.springframework.ai.vectorstore.filter.Filter; -import org.springframework.ai.vectorstore.filter.Filter.Expression; -import org.springframework.ai.vectorstore.filter.Filter.Key; -import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; +package org.springframework.ai.vectorstore; import java.text.ParseException; import java.text.SimpleDateFormat; @@ -27,6 +23,11 @@ import java.util.List; import java.util.TimeZone; import java.util.regex.Pattern; +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; + /** * ElasticsearchAiSearchFilterExpressionConverter is a class that converts * Filter.Expression objects into Elasticsearch query string representation. It extends diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java index bf9932b37..32731f754 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static java.lang.Math.sqrt; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.util.List; @@ -24,10 +23,22 @@ import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch.core.BulkRequest; +import co.elastic.clients.elasticsearch.core.BulkResponse; +import co.elastic.clients.elasticsearch.core.SearchResponse; +import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem; +import co.elastic.clients.elasticsearch.core.search.Hit; +import co.elastic.clients.json.jackson.JacksonJsonpMapper; import co.elastic.clients.transport.Version; +import co.elastic.clients.transport.rest_client.RestClientTransport; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.micrometer.observation.ObservationRegistry; import org.elasticsearch.client.RestClient; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -45,18 +56,7 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationConv import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; - -import co.elastic.clients.elasticsearch.ElasticsearchClient; -import co.elastic.clients.elasticsearch.core.BulkRequest; -import co.elastic.clients.elasticsearch.core.BulkResponse; -import co.elastic.clients.elasticsearch.core.SearchResponse; -import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem; -import co.elastic.clients.elasticsearch.core.search.Hit; -import co.elastic.clients.json.jackson.JacksonJsonpMapper; -import co.elastic.clients.transport.rest_client.RestClientTransport; -import io.micrometer.observation.ObservationRegistry; +import static java.lang.Math.sqrt; /** * The ElasticsearchVectorStore class implements the VectorStore interface and provides @@ -79,6 +79,10 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp private static final Logger logger = LoggerFactory.getLogger(ElasticsearchVectorStore.class); + private static Map SIMILARITY_TYPE_MAPPING = Map.of( + SimilarityFunction.cosine, VectorStoreSimilarityMetric.COSINE, SimilarityFunction.l2_norm, + VectorStoreSimilarityMetric.EUCLIDEAN, SimilarityFunction.dot_product, VectorStoreSimilarityMetric.DOT); + private final EmbeddingModel embeddingModel; private final ElasticsearchClient elasticsearchClient; @@ -176,14 +180,14 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp try { float threshold = (float) searchRequest.getSimilarityThreshold(); // reverting l2_norm distance to its original value - if (options.getSimilarity().equals(SimilarityFunction.l2_norm)) { + if (this.options.getSimilarity().equals(SimilarityFunction.l2_norm)) { threshold = 1 - threshold; } final float finalThreshold = threshold; float[] vectors = this.embeddingModel.embed(searchRequest.getQuery()); - SearchResponse res = elasticsearchClient.search( - sr -> sr.index(options.getIndexName()) + SearchResponse res = this.elasticsearchClient.search( + sr -> sr.index(this.options.getIndexName()) .knn(knn -> knn.queryVector(EmbeddingUtils.toList(vectors)) .similarity(finalThreshold) .k((long) searchRequest.getTopK()) @@ -215,7 +219,7 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp // more info on score/distance calculation // https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#knn-similarity-search private float calculateDistance(Float score) { - switch (options.getSimilarity()) { + switch (this.options.getSimilarity()) { case l2_norm: // the returned value of l2_norm is the opposite of the other functions // (closest to zero means more accurate), so to make it consistent @@ -230,7 +234,7 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp public boolean indexExists() { try { - return this.elasticsearchClient.indices().exists(ex -> ex.index(options.getIndexName())).value(); + return this.elasticsearchClient.indices().exists(ex -> ex.index(this.options.getIndexName())).value(); } catch (IOException e) { throw new RuntimeException(e); @@ -240,9 +244,10 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp private void createIndexMapping() { try { this.elasticsearchClient.indices() - .create(cr -> cr.index(options.getIndexName()) - .mappings(map -> map.properties("embedding", p -> p.denseVector( - dv -> dv.similarity(options.getSimilarity().toString()).dims(options.getDimensions()))))); + .create(cr -> cr.index(this.options.getIndexName()) + .mappings(map -> map.properties("embedding", + p -> p.denseVector(dv -> dv.similarity(this.options.getSimilarity().toString()) + .dims(this.options.getDimensions()))))); } catch (IOException e) { throw new RuntimeException(e); @@ -267,10 +272,6 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp .withSimilarityMetric(getSimilarityMetric()); } - private static Map SIMILARITY_TYPE_MAPPING = Map.of( - SimilarityFunction.cosine, VectorStoreSimilarityMetric.COSINE, SimilarityFunction.l2_norm, - VectorStoreSimilarityMetric.EUCLIDEAN, SimilarityFunction.dot_product, VectorStoreSimilarityMetric.DOT); - private String getSimilarityMetric() { if (!SIMILARITY_TYPE_MAPPING.containsKey(this.options.getSimilarity())) { return this.options.getSimilarity().name(); diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java index c68512246..8aee18ac0 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; /** @@ -40,7 +41,7 @@ public class ElasticsearchVectorStoreOptions { private SimilarityFunction similarity = SimilarityFunction.cosine; public String getIndexName() { - return indexName; + return this.indexName; } public void setIndexName(String indexName) { @@ -48,7 +49,7 @@ public class ElasticsearchVectorStoreOptions { } public int getDimensions() { - return dimensions; + return this.dimensions; } public void setDimensions(int dims) { @@ -56,7 +57,7 @@ public class ElasticsearchVectorStoreOptions { } public SimilarityFunction getSimilarity() { - return similarity; + return this.similarity; } public void setSimilarity(SimilarityFunction similarity) { diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java index 86fc84c01..b28e7313d 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore; /** diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverterTest.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverterTest.java index 382209646..7a6f737ee 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverterTest.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverterTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import org.junit.jupiter.api.Test; -import org.springframework.ai.vectorstore.filter.Filter; -import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; +package org.springframework.ai.vectorstore; import java.util.Date; import java.util.List; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; + import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AND; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ; @@ -38,25 +40,25 @@ class ElasticsearchAiSearchFilterExpressionConverterTest { @Test public void testDate() { - String vectorExpr = converter.convertExpression(new Filter.Expression(EQ, new Filter.Key("activationDate"), + String vectorExpr = this.converter.convertExpression(new Filter.Expression(EQ, new Filter.Key("activationDate"), new Filter.Value(new Date(1704637752148L)))); assertThat(vectorExpr).isEqualTo("metadata.activationDate:2024-01-07T14:29:12Z"); - vectorExpr = converter.convertExpression( + vectorExpr = this.converter.convertExpression( new Filter.Expression(EQ, new Filter.Key("activationDate"), new Filter.Value("1970-01-01T00:00:02Z"))); assertThat(vectorExpr).isEqualTo("metadata.activationDate:1970-01-01T00:00:02Z"); } @Test public void testEQ() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Filter.Expression(EQ, new Filter.Key("country"), new Filter.Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata.country:BG"); } @Test public void tesEqAndGte() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Expression(EQ, new Filter.Key("genre"), new Filter.Value("drama")), new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020)))); assertThat(vectorExpr).isEqualTo("metadata.genre:drama AND metadata.year:>=2020"); @@ -64,14 +66,14 @@ class ElasticsearchAiSearchFilterExpressionConverterTest { @Test public void tesIn() { - String vectorExpr = converter.convertExpression(new Filter.Expression(IN, new Filter.Key("genre"), + String vectorExpr = this.converter.convertExpression(new Filter.Expression(IN, new Filter.Key("genre"), new Filter.Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("(metadata.genre:comedy OR documentary OR drama)"); } @Test public void testNe() { - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Filter.Expression(OR, new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020)), new Filter.Expression(AND, new Filter.Expression(EQ, new Filter.Key("country"), new Filter.Value("BG")), @@ -81,7 +83,7 @@ class ElasticsearchAiSearchFilterExpressionConverterTest { @Test public void testGroup() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Group(new Filter.Expression(OR, new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020)), new Filter.Expression(EQ, new Filter.Key("country"), new Filter.Value("BG")))), @@ -92,7 +94,7 @@ class ElasticsearchAiSearchFilterExpressionConverterTest { @Test public void tesBoolean() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Expression(AND, new Filter.Expression(EQ, new Filter.Key("isOpen"), new Filter.Value(true)), new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020))), new Filter.Expression(IN, new Filter.Key("country"), new Filter.Value(List.of("BG", "NL", "US"))))); @@ -103,7 +105,7 @@ class ElasticsearchAiSearchFilterExpressionConverterTest { @Test public void testDecimal() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Expression(GTE, new Filter.Key("temperature"), new Filter.Value(-15.6)), new Filter.Expression(LTE, new Filter.Key("temperature"), new Filter.Value(20.13)))); @@ -112,11 +114,11 @@ class ElasticsearchAiSearchFilterExpressionConverterTest { @Test public void testComplexIdentifiers() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Filter.Expression(EQ, new Filter.Key("\"country 1 2 3\""), new Filter.Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata.country 1 2 3:BG"); - vectorExpr = converter + vectorExpr = this.converter .convertExpression(new Filter.Expression(EQ, new Filter.Key("'country 1 2 3'"), new Filter.Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata.country 1 2 3:BG"); } diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchImage.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchImage.java index 2697c1950..db8b68f3b 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchImage.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java index c81972bac..c262c9a4a 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.io.IOException; @@ -121,7 +122,7 @@ class ElasticsearchVectorStoreIT { assertThat(stats.total().docs().count()).isEqualTo(0L); - vectorStore.add(documents); + vectorStore.add(this.documents); elasticsearchClient.indices().refresh(); stats = elasticsearchClient.indices() .stats(s -> s.index("spring-ai-document-index")) @@ -148,7 +149,7 @@ class ElasticsearchVectorStoreIT { ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, ElasticsearchVectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await() .until(() -> vectorStore @@ -160,14 +161,14 @@ class ElasticsearchVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore @@ -266,7 +267,7 @@ class ElasticsearchVectorStoreIT { assertThat(results.get(0).getId()).isEqualTo(bgDocument2.getId()); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(1)), hasSize(0)); @@ -334,7 +335,7 @@ class ElasticsearchVectorStoreIT { ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, ElasticsearchVectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); SearchRequest query = SearchRequest.query("Great Depression").withTopK(50).withSimilarityThresholdAll(); @@ -353,13 +354,13 @@ class ElasticsearchVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore.similaritySearch( diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java index 4a0c9cc58..7040b97a6 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,6 +23,15 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch.cat.indices.IndicesRecord; +import co.elastic.clients.json.jackson.JacksonJsonpMapper; +import co.elastic.clients.transport.rest_client.RestClientTransport; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.apache.http.HttpHost; import org.awaitility.Awaitility; import org.elasticsearch.client.RestClient; @@ -31,6 +39,10 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.testcontainers.elasticsearch.ElasticsearchContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -48,21 +60,11 @@ import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.elasticsearch.ElasticsearchContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.greaterThan; -import co.elastic.clients.elasticsearch.ElasticsearchClient; -import co.elastic.clients.elasticsearch.cat.indices.IndicesRecord; -import co.elastic.clients.json.jackson.JacksonJsonpMapper; -import co.elastic.clients.transport.rest_client.RestClientTransport; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import static org.hamcrest.Matchers.greaterThan;; +; /** * @author Christian Tzolov @@ -92,10 +94,6 @@ public class ElasticsearchVectorStoreObservationIT { } } - private ApplicationContextRunner getContextRunner() { - return new ApplicationContextRunner().withUserConfiguration(Config.class); - } - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -103,6 +101,10 @@ public class ElasticsearchVectorStoreObservationIT { Awaitility.setDefaultTimeout(Duration.ofMinutes(1)); } + private ApplicationContextRunner getContextRunner() { + return new ApplicationContextRunner().withUserConfiguration(Config.class); + } + @BeforeEach void cleanDatabase() { getContextRunner().run(context -> { @@ -124,7 +126,7 @@ public class ElasticsearchVectorStoreObservationIT { TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-gemfire-store/pom.xml b/vector-stores/spring-ai-gemfire-store/pom.xml index e2dcdff04..25ab313c8 100644 --- a/vector-stores/spring-ai-gemfire-store/pom.xml +++ b/vector-stores/spring-ai-gemfire-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java index 416e55053..7f3390d6d 100644 --- a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java +++ b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java @@ -16,17 +16,22 @@ package org.springframework.ai.vectorstore; -import static org.springframework.http.HttpStatus.BAD_REQUEST; -import static org.springframework.http.HttpStatus.NOT_FOUND; - import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.json.JsonMapper; +import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.util.annotation.NonNull; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -47,14 +52,8 @@ import org.springframework.web.reactive.function.client.WebClientException; import org.springframework.web.reactive.function.client.WebClientResponseException; import org.springframework.web.util.UriComponentsBuilder; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.micrometer.observation.ObservationRegistry; -import reactor.util.annotation.NonNull; +import static org.springframework.http.HttpStatus.BAD_REQUEST; +import static org.springframework.http.HttpStatus.NOT_FOUND; /** * A VectorStore implementation backed by GemFire. This store supports creating, updating, diff --git a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireImage.java b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireImage.java index 806497e25..3d204767a 100644 --- a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireImage.java +++ b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java index 42ecfc338..86c71724a 100644 --- a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java +++ b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static java.util.concurrent.TimeUnit.MINUTES; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -34,6 +31,7 @@ import org.awaitility.Awaitility; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -43,6 +41,10 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; +import static java.util.concurrent.TimeUnit.MINUTES; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; + /** * @author Geet Rawat * @author Soby Chacko @@ -53,14 +55,22 @@ public class GemFireVectorStoreIT { public static final String INDEX_NAME = "spring-ai-index1"; - private static GemFireCluster gemFireCluster; - private static final int HTTP_SERVICE_PORT = 9090; private static final int LOCATOR_COUNT = 1; private static final int SERVER_COUNT = 1; + private static GemFireCluster gemFireCluster; + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + + List documents = List.of( + new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), + new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), + new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + @AfterAll public static void stopGemFireCluster() { gemFireCluster.close(); @@ -83,11 +93,6 @@ public class GemFireVectorStoreIT { String.format("localhost[%d]", gemFireCluster.getLocatorPort())); } - List documents = List.of( - new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), - new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), - new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); - public static String getText(String uri) { var resource = new DefaultResourceLoader().getResource(uri); try { @@ -98,15 +103,12 @@ public class GemFireVectorStoreIT { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(TestApplication.class); - @Test public void addAndDeleteEmbeddingTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.add(this.documents); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); Awaitility.await() .atMost(1, MINUTES) .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(3)), @@ -116,9 +118,9 @@ public class GemFireVectorStoreIT { @Test public void addAndSearchTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await() .atMost(1, MINUTES) @@ -127,7 +129,7 @@ public class GemFireVectorStoreIT { List results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(5)); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939)" + " was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); @@ -137,7 +139,7 @@ public class GemFireVectorStoreIT { @Test public void documentUpdateTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", @@ -175,9 +177,9 @@ public class GemFireVectorStoreIT { @Test public void searchThresholdTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await() .atMost(1, MINUTES) @@ -198,7 +200,7 @@ public class GemFireVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression " + "(1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); diff --git a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java index 5cdd50591..abf2374c0 100644 --- a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import com.github.dockerjava.api.model.ExposedPort; +import com.github.dockerjava.api.model.PortBinding; +import com.github.dockerjava.api.model.Ports; +import com.vmware.gemfire.testcontainers.GemFireCluster; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.awaitility.Awaitility; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -41,16 +48,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import com.github.dockerjava.api.model.ExposedPort; -import com.github.dockerjava.api.model.PortBinding; -import com.github.dockerjava.api.model.Ports; -import com.vmware.gemfire.testcontainers.GemFireCluster; - -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; - import static java.util.concurrent.TimeUnit.MINUTES; +import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.Matchers.hasSize; /** @@ -62,14 +61,22 @@ public class GemFireVectorStoreObservationIT { public static final String TEST_INDEX_NAME = "spring-ai-index1"; - private static GemFireCluster gemFireCluster; - private static final int HTTP_SERVICE_PORT = 9090; private static final int LOCATOR_COUNT = 1; private static final int SERVER_COUNT = 1; + private static GemFireCluster gemFireCluster; + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class); + + List documents = List.of( + new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), + new Document(getText("classpath:/test/data/time.shelter.txt")), + new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + @AfterAll public static void stopGemFireCluster() { gemFireCluster.close(); @@ -92,14 +99,6 @@ public class GemFireVectorStoreObservationIT { String.format("localhost[%d]", gemFireCluster.getLocatorPort())); } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class); - - List documents = List.of( - new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), - new Document(getText("classpath:/test/data/time.shelter.txt")), - new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); - public static String getText(String uri) { var resource = new DefaultResourceLoader().getResource(uri); try { @@ -113,13 +112,13 @@ public class GemFireVectorStoreObservationIT { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-hanadb-store/pom.xml b/vector-stores/spring-ai-hanadb-store/pom.xml index a5fb9e346..b794e9164 100644 --- a/vector-stores/spring-ai-hanadb-store/pom.xml +++ b/vector-stores/spring-ai-hanadb-store/pom.xml @@ -1,4 +1,20 @@ + + diff --git a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStore.java b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStore.java index d420a206d..89fbf5e27 100644 --- a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStore.java +++ b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStore.java @@ -15,14 +15,18 @@ */ package org.springframework.ai.vectorstore; -import com.fasterxml.jackson.core.JsonProcessingException; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.json.JsonMapper; import io.micrometer.observation.ObservationRegistry; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.model.EmbeddingUtils; @@ -34,11 +38,6 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationCont import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext.Builder; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; - /** * The SAP HANA Cloud vector engine offers multiple use cases in AI scenarios. * diff --git a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreConfig.java b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreConfig.java index 8e69e66e2..b8b8faff0 100644 --- a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreConfig.java +++ b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreConfig.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; /** @@ -37,11 +38,11 @@ public class HanaCloudVectorStoreConfig { } public String getTableName() { - return tableName; + return this.tableName; } public int getTopK() { - return topK; + return this.topK; } public static class HanaCloudVectorStoreConfigBuilder { @@ -62,8 +63,8 @@ public class HanaCloudVectorStoreConfig { public HanaCloudVectorStoreConfig build() { HanaCloudVectorStoreConfig config = new HanaCloudVectorStoreConfig(); - config.tableName = tableName; - config.topK = topK; + config.tableName = this.tableName; + config.topK = this.topK; return config; } diff --git a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorEntity.java b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorEntity.java index 439b4f88f..7b7109c6a 100644 --- a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorEntity.java +++ b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorEntity.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import jakarta.persistence.Column; @@ -39,7 +40,7 @@ public abstract class HanaVectorEntity { } public String get_id() { - return _id; + return this._id; } } diff --git a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorRepository.java b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorRepository.java index 16e33285c..e962006f1 100644 --- a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorRepository.java +++ b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaVectorRepository.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; diff --git a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCup.java b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCup.java index ed1f1e393..8853ac37f 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCup.java +++ b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCup.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import jakarta.persistence.Column; @@ -31,7 +32,7 @@ public class CricketWorldCup extends HanaVectorEntity { private String content; public String getContent() { - return content; + return this.content; } } diff --git a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupHanaController.java b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupHanaController.java index 392e876b4..48d97dd26 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupHanaController.java +++ b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupHanaController.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.model.ChatModel; + import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.document.Document; @@ -33,13 +42,6 @@ import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.multipart.MultipartFile; -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.function.Function; -import java.util.function.Supplier; -import java.util.stream.Collectors; - /** * @author Rahul Mittal * @since 1.0.0 @@ -74,7 +76,7 @@ public class CricketWorldCupHanaController { Function, List> splitter = new TokenTextSplitter(); List documents = splitter.apply(reader.get()); logger.info("{} documents created from pdf file: {}", documents.size(), pdf.getFilename()); - hanaCloudVectorStore.accept(documents); + this.hanaCloudVectorStore.accept(documents); return ResponseEntity.ok() .body(String.format("%d documents created from pdf file: %s", documents.size(), pdf.getFilename())); } @@ -88,7 +90,7 @@ public class CricketWorldCupHanaController { var userMessage = new UserMessage(message); Prompt prompt = new Prompt(List.of(similarDocsMessage, userMessage)); - String generation = chatModel.call(prompt).getResult().getOutput().getContent(); + String generation = this.chatModel.call(prompt).getResult().getOutput().getContent(); logger.info("Generation: {}", generation); return Map.of("generation", generation); } diff --git a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupRepository.java b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupRepository.java index 397ea39be..2a9aac568 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupRepository.java +++ b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupRepository.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.List; + import jakarta.persistence.EntityManager; import jakarta.persistence.PersistenceContext; import jakarta.transaction.Transactional; -import org.springframework.stereotype.Repository; -import java.util.List; +import org.springframework.stereotype.Repository; /** * @author Rahul Mittal @@ -40,7 +42,7 @@ public class CricketWorldCupRepository implements HanaVectorRepository { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(HanaCloudVectorStore.class); int deleteCount = ((HanaCloudVectorStore) vectorStore).purgeEmbeddings(); @@ -128,4 +129,4 @@ public class HanaCloudVectorStoreIT { } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/HanaVectorStoreObservationIT.java b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/HanaVectorStoreObservationIT.java index 78c324a4d..c61fb2aaf 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/HanaVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/HanaVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,8 +23,12 @@ import java.util.Map; import javax.sql.DataSource; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.SpringAiKind; @@ -46,9 +49,7 @@ import org.springframework.orm.jpa.JpaVendorAdapter; import org.springframework.orm.jpa.LocalContainerEntityManagerFactoryBean; import org.springframework.orm.jpa.vendor.HibernateJpaVendorAdapter; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -62,6 +63,9 @@ public class HanaVectorStoreObservationIT { private static final String TEST_TABLE_NAME = "CRICKET_WORLD_CUP"; + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -77,19 +81,16 @@ public class HanaVectorStoreObservationIT { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class); - @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-hanadb-store/src/test/resources/application.properties b/vector-stores/spring-ai-hanadb-store/src/test/resources/application.properties index faf788a4a..f2d9b9274 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/resources/application.properties +++ b/vector-stores/spring-ai-hanadb-store/src/test/resources/application.properties @@ -1,3 +1,19 @@ +# +# Copyright 2023-2024 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + spring.ai.openai.api-key=${OPENAI_API_KEY} spring.ai.openai.embedding.options.model=text-embedding-ada-002 diff --git a/vector-stores/spring-ai-milvus-store/pom.xml b/vector-stores/spring-ai-milvus-store/pom.xml index 8fe98ecde..bdf72f777 100644 --- a/vector-stores/spring-ai-milvus-store/pom.xml +++ b/vector-stores/spring-ai-milvus-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverter.java b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverter.java index c3d6d1a2d..c7e9a0939 100644 --- a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverter.java +++ b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.springframework.ai.vectorstore.filter.Filter.Expression; diff --git a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java index b1141949a..fe9114234 100644 --- a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java +++ b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + import com.alibaba.fastjson.JSONObject; import io.micrometer.observation.ObservationRegistry; import io.milvus.client.MilvusServiceClient; @@ -44,6 +51,7 @@ import io.milvus.response.QueryResultsWrapper.RowRecord; import io.milvus.response.SearchResultsWrapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -60,12 +68,6 @@ import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; - /** * @author Christian Tzolov * @author Soby Chacko @@ -73,8 +75,6 @@ import java.util.stream.Collectors; */ public class MilvusVectorStore extends AbstractObservationVectorStore implements InitializingBean { - private static final Logger logger = LoggerFactory.getLogger(MilvusVectorStore.class); - public static final int OPENAI_EMBEDDING_DIMENSION_SIZE = 1536; public static final int INVALID_EMBEDDING_DIMENSION = -1; @@ -97,6 +97,12 @@ public class MilvusVectorStore extends AbstractObservationVectorStore implements public static final List SEARCH_OUTPUT_FIELDS = List.of(DOC_ID_FIELD_NAME, CONTENT_FIELD_NAME, METADATA_FIELD_NAME); + private static final Logger logger = LoggerFactory.getLogger(MilvusVectorStore.class); + + private static Map SIMILARITY_TYPE_MAPPING = Map.of(MetricType.COSINE, + VectorStoreSimilarityMetric.COSINE, MetricType.L2, VectorStoreSimilarityMetric.EUCLIDEAN, MetricType.IP, + VectorStoreSimilarityMetric.DOT); + public final FilterExpressionConverter filterExpressionConverter = new MilvusFilterExpressionConverter(); private final MilvusServiceClient milvusClient; @@ -109,151 +115,6 @@ public class MilvusVectorStore extends AbstractObservationVectorStore implements private final BatchingStrategy batchingStrategy; - /** - * Configuration for the Milvus vector store. - */ - public static class MilvusVectorStoreConfig { - - private final String databaseName; - - private final String collectionName; - - private final int embeddingDimension; - - private final IndexType indexType; - - private final MetricType metricType; - - private final String indexParameters; - - /** - * Start building a new configuration. - * @return The entry point for creating a new configuration. - */ - public static Builder builder() { - - return new Builder(); - } - - /** - * {@return the default config} - */ - public static MilvusVectorStoreConfig defaultConfig() { - return builder().build(); - } - - private MilvusVectorStoreConfig(Builder builder) { - this.databaseName = builder.databaseName; - this.collectionName = builder.collectionName; - this.embeddingDimension = builder.embeddingDimension; - this.indexType = builder.indexType; - this.metricType = builder.metricType; - this.indexParameters = builder.indexParameters; - } - - public static class Builder { - - private String databaseName = DEFAULT_DATABASE_NAME; - - private String collectionName = DEFAULT_COLLECTION_NAME; - - private int embeddingDimension = INVALID_EMBEDDING_DIMENSION; - - private IndexType indexType = IndexType.IVF_FLAT; - - private MetricType metricType = MetricType.COSINE; - - private String indexParameters = "{\"nlist\":1024}"; - - private Builder() { - } - - /** - * Configures the Milvus metric type to use. Leave {@literal null} or blank to - * use the metric metric: https://milvus.io/docs/metric.md#floating - * @param metricType the metric type to use - * @return this builder - */ - public Builder withMetricType(MetricType metricType) { - Assert.notNull(metricType, "Collection Name must not be empty"); - Assert.isTrue( - metricType == MetricType.IP || metricType == MetricType.L2 || metricType == MetricType.COSINE, - "Only the text metric types IP and L2 are supported"); - - this.metricType = metricType; - return this; - } - - /** - * Configures the Milvus index type to use. Leave {@literal null} or blank to - * use the default index. - * @param indexType the index type to use - * @return this builder - */ - public Builder withIndexType(IndexType indexType) { - this.indexType = indexType; - return this; - } - - /** - * Configures the Milvus index parameters to use. Leave {@literal null} or - * blank to use the default index parameters. - * @param indexParameters the index parameters to use - * @return this builder - */ - public Builder withIndexParameters(String indexParameters) { - this.indexParameters = indexParameters; - return this; - } - - /** - * Configures the Milvus database name to use. Leave {@literal null} or blank - * to use the default database. - * @param databaseName the database name to use - * @return this builder - */ - public Builder withDatabaseName(String databaseName) { - this.databaseName = databaseName; - return this; - } - - /** - * Configures the Milvus collection name to use. Leave {@literal null} or - * blank to use the default collection name. - * @param collectionName the collection name to use - * @return this builder - */ - public Builder withCollectionName(String collectionName) { - this.collectionName = collectionName; - return this; - } - - /** - * Configures the size of the embedding. Defaults to {@literal 1536}, inline - * with OpenAIs embeddings. - * @param newEmbeddingDimension The dimension of the embedding - * @return this builder - */ - public Builder withEmbeddingDimension(int newEmbeddingDimension) { - - Assert.isTrue(newEmbeddingDimension >= 1 && newEmbeddingDimension <= 32768, - "Dimension has to be withing the boundaries 1 and 32768 (inclusively)"); - - this.embeddingDimension = newEmbeddingDimension; - return this; - } - - /** - * {@return the immutable configuration} - */ - public MilvusVectorStoreConfig build() { - return new MilvusVectorStoreConfig(this); - } - - } - - } - public MilvusVectorStore(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel, boolean initializeSchema) { this(milvusClient, embeddingModel, MilvusVectorStoreConfig.defaultConfig(), initializeSchema, @@ -369,7 +230,7 @@ public class MilvusVectorStore extends AbstractObservationVectorStore implements searchParamBuilder.withExpr(nativeFilterExpressions); } - R respSearch = milvusClient.search(searchParamBuilder.build()); + R respSearch = this.milvusClient.search(searchParamBuilder.build()); if (respSearch.getException() != null) { throw new RuntimeException("Search failed!", respSearch.getException()); @@ -558,10 +419,6 @@ public class MilvusVectorStore extends AbstractObservationVectorStore implements .withNamespace(this.config.databaseName); } - private static Map SIMILARITY_TYPE_MAPPING = Map.of(MetricType.COSINE, - VectorStoreSimilarityMetric.COSINE, MetricType.L2, VectorStoreSimilarityMetric.EUCLIDEAN, MetricType.IP, - VectorStoreSimilarityMetric.DOT); - private String getSimilarityMetric() { if (!SIMILARITY_TYPE_MAPPING.containsKey(this.config.metricType)) { return this.config.metricType.name(); @@ -569,4 +426,149 @@ public class MilvusVectorStore extends AbstractObservationVectorStore implements return SIMILARITY_TYPE_MAPPING.get(this.config.metricType).value(); } + /** + * Configuration for the Milvus vector store. + */ + public static class MilvusVectorStoreConfig { + + private final String databaseName; + + private final String collectionName; + + private final int embeddingDimension; + + private final IndexType indexType; + + private final MetricType metricType; + + private final String indexParameters; + + private MilvusVectorStoreConfig(Builder builder) { + this.databaseName = builder.databaseName; + this.collectionName = builder.collectionName; + this.embeddingDimension = builder.embeddingDimension; + this.indexType = builder.indexType; + this.metricType = builder.metricType; + this.indexParameters = builder.indexParameters; + } + + /** + * Start building a new configuration. + * @return The entry point for creating a new configuration. + */ + public static Builder builder() { + + return new Builder(); + } + + /** + * {@return the default config} + */ + public static MilvusVectorStoreConfig defaultConfig() { + return builder().build(); + } + + public static class Builder { + + private String databaseName = DEFAULT_DATABASE_NAME; + + private String collectionName = DEFAULT_COLLECTION_NAME; + + private int embeddingDimension = INVALID_EMBEDDING_DIMENSION; + + private IndexType indexType = IndexType.IVF_FLAT; + + private MetricType metricType = MetricType.COSINE; + + private String indexParameters = "{\"nlist\":1024}"; + + private Builder() { + } + + /** + * Configures the Milvus metric type to use. Leave {@literal null} or blank to + * use the metric metric: https://milvus.io/docs/metric.md#floating + * @param metricType the metric type to use + * @return this builder + */ + public Builder withMetricType(MetricType metricType) { + Assert.notNull(metricType, "Collection Name must not be empty"); + Assert.isTrue( + metricType == MetricType.IP || metricType == MetricType.L2 || metricType == MetricType.COSINE, + "Only the text metric types IP and L2 are supported"); + + this.metricType = metricType; + return this; + } + + /** + * Configures the Milvus index type to use. Leave {@literal null} or blank to + * use the default index. + * @param indexType the index type to use + * @return this builder + */ + public Builder withIndexType(IndexType indexType) { + this.indexType = indexType; + return this; + } + + /** + * Configures the Milvus index parameters to use. Leave {@literal null} or + * blank to use the default index parameters. + * @param indexParameters the index parameters to use + * @return this builder + */ + public Builder withIndexParameters(String indexParameters) { + this.indexParameters = indexParameters; + return this; + } + + /** + * Configures the Milvus database name to use. Leave {@literal null} or blank + * to use the default database. + * @param databaseName the database name to use + * @return this builder + */ + public Builder withDatabaseName(String databaseName) { + this.databaseName = databaseName; + return this; + } + + /** + * Configures the Milvus collection name to use. Leave {@literal null} or + * blank to use the default collection name. + * @param collectionName the collection name to use + * @return this builder + */ + public Builder withCollectionName(String collectionName) { + this.collectionName = collectionName; + return this; + } + + /** + * Configures the size of the embedding. Defaults to {@literal 1536}, inline + * with OpenAIs embeddings. + * @param newEmbeddingDimension The dimension of the embedding + * @return this builder + */ + public Builder withEmbeddingDimension(int newEmbeddingDimension) { + + Assert.isTrue(newEmbeddingDimension >= 1 && newEmbeddingDimension <= 32768, + "Dimension has to be withing the boundaries 1 and 32768 (inclusively)"); + + this.embeddingDimension = newEmbeddingDimension; + return this; + } + + /** + * {@return the immutable configuration} + */ + public MilvusVectorStoreConfig build() { + return new MilvusVectorStoreConfig(this); + } + + } + + } + } diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusEmbeddingDimensionsTests.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusEmbeddingDimensionsTests.java index 78538ee8f..bc60af7fa 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusEmbeddingDimensionsTests.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusEmbeddingDimensionsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import io.milvus.client.MilvusServiceClient; @@ -57,38 +58,40 @@ public class MilvusEmbeddingDimensionsTests { .withEmbeddingDimension(explicitDimensions) .build(); - var dim = new MilvusVectorStore(milvusClient, embeddingModel, config, true, new TokenCountBatchingStrategy()) + var dim = new MilvusVectorStore(this.milvusClient, this.embeddingModel, config, true, + new TokenCountBatchingStrategy()) .embeddingDimensions(); assertThat(dim).isEqualTo(explicitDimensions); - verify(embeddingModel, never()).dimensions(); + verify(this.embeddingModel, never()).dimensions(); } @Test public void embeddingModelDimensions() { - when(embeddingModel.dimensions()).thenReturn(969); + when(this.embeddingModel.dimensions()).thenReturn(969); MilvusVectorStoreConfig config = MilvusVectorStoreConfig.builder().build(); - var dim = new MilvusVectorStore(milvusClient, embeddingModel, config, true, new TokenCountBatchingStrategy()) + var dim = new MilvusVectorStore(this.milvusClient, this.embeddingModel, config, true, + new TokenCountBatchingStrategy()) .embeddingDimensions(); assertThat(dim).isEqualTo(969); - verify(embeddingModel, only()).dimensions(); + verify(this.embeddingModel, only()).dimensions(); } @Test public void fallBackToDefaultDimensions() { - when(embeddingModel.dimensions()).thenThrow(new RuntimeException()); + when(this.embeddingModel.dimensions()).thenThrow(new RuntimeException()); - var dim = new MilvusVectorStore(milvusClient, embeddingModel, MilvusVectorStoreConfig.builder().build(), true, - new TokenCountBatchingStrategy()) + var dim = new MilvusVectorStore(this.milvusClient, this.embeddingModel, + MilvusVectorStoreConfig.builder().build(), true, new TokenCountBatchingStrategy()) .embeddingDimensions(); assertThat(dim).isEqualTo(MilvusVectorStore.OPENAI_EMBEDDING_DIMENSION_SIZE); - verify(embeddingModel, only()).dimensions(); + verify(this.embeddingModel, only()).dimensions(); } @ParameterizedTest diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverterTests.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverterTests.java index afab2766a..dd0b8ef82 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; @@ -45,14 +46,14 @@ public class MilvusFilterExpressionConverterTests { @Test public void testEQ() { // country == "BG" - String vectorExpr = converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); + String vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata[\"country\"] == \"BG\""); } @Test public void tesEqAndGte() { // genre == "drama" AND year >= 2020 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); assertThat(vectorExpr).isEqualTo("metadata[\"genre\"] == \"drama\" && metadata[\"year\"] >= 2020"); @@ -61,7 +62,7 @@ public class MilvusFilterExpressionConverterTests { @Test public void tesIn() { // genre in ["comedy", "documentary", "drama"] - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("metadata[\"genre\"] in [\"comedy\",\"documentary\",\"drama\"]"); } @@ -69,7 +70,7 @@ public class MilvusFilterExpressionConverterTests { @Test public void testNe() { // year >= 2020 OR country == "BG" AND city != "Sofia" - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), new Expression(NE, new Key("city"), new Value("Sofia"))))); @@ -80,7 +81,7 @@ public class MilvusFilterExpressionConverterTests { @Test public void testGroup() { // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(EQ, new Key("country"), new Value("BG")))), new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv"))))); @@ -91,7 +92,7 @@ public class MilvusFilterExpressionConverterTests { @Test public void tesBoolean() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), new Expression(GTE, new Key("year"), new Value(2020))), new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); @@ -103,7 +104,7 @@ public class MilvusFilterExpressionConverterTests { @Test public void testDecimal() { // temperature >= -15.6 && temperature <= +20.13 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), new Expression(LTE, new Key("temperature"), new Value(20.13)))); @@ -112,11 +113,11 @@ public class MilvusFilterExpressionConverterTests { @Test public void testComplexIdentifiers() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata[\"country 1 2 3\"] == \"BG\""); - vectorExpr = converter.convertExpression(new Expression(EQ, new Key("'country 1 2 3'"), new Value("BG"))); + vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("'country 1 2 3'"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata[\"country 1 2 3\"] == \"BG\""); } diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusImage.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusImage.java index ffdcd3c4b..8212474bd 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusImage.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java index 118b96c67..98a88e7b7 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.io.IOException; @@ -31,6 +32,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.milvus.MilvusContainer; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -45,7 +47,6 @@ import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.milvus.MilvusContainer; import static org.assertj.core.api.Assertions.assertThat; @@ -88,30 +89,31 @@ public class MilvusVectorStoreIT { @ValueSource(strings = { "COSINE", "L2", "IP" }) public void addAndSearch(String metricType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType).run(context -> { + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType) + .run(context -> { - VectorStore vectorStore = context.getBean(VectorStore.class); + VectorStore vectorStore = context.getBean(VectorStore.class); - resetCollection(vectorStore); + resetCollection(vectorStore); - vectorStore.add(documents); + vectorStore.add(this.documents); - List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); + List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); - assertThat(results).hasSize(1); - Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); - assertThat(resultDoc.getContent()).contains( - "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).hasSize(2); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); + assertThat(resultDoc.getContent()).contains( + "Spring AI provides abstractions that serve as the foundation for developing AI applications."); + assertThat(resultDoc.getMetadata()).hasSize(2); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); - // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + // Remove all documents from the store + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); - results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); - assertThat(results).hasSize(0); - }); + results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); + assertThat(results).hasSize(0); + }); } @ParameterizedTest(name = "{0} : {displayName} ") @@ -121,135 +123,140 @@ public class MilvusVectorStoreIT { // https://milvus.io/docs/json_data_type.md - contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType).run(context -> { - VectorStore vectorStore = context.getBean(VectorStore.class); + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType) + .run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); - resetCollection(vectorStore); + resetCollection(vectorStore); - var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", - Map.of("country", "BG", "year", 2020)); - var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner", - Map.of("country", "NL")); - var bgDocument2 = new Document("The World is Big and Salvation Lurks Around the Corner", - Map.of("country", "BG", "year", 2023)); + var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "BG", "year", 2020)); + var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "NL")); + var bgDocument2 = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "BG", "year", 2023)); - vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2)); + vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2)); - List results = vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(5)); - assertThat(results).hasSize(3); + List results = vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(5)); + assertThat(results).hasSize(3); - results = vectorStore.similaritySearch(SearchRequest.query("The World") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression("country == 'NL'")); - assertThat(results).hasSize(1); - assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("country == 'NL'")); + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); - results = vectorStore.similaritySearch(SearchRequest.query("The World") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression("country == 'BG'")); + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("country == 'BG'")); - assertThat(results).hasSize(2); - assertThat(results.get(0).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); - assertThat(results.get(1).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); + assertThat(results).hasSize(2); + assertThat(results.get(0).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); + assertThat(results.get(1).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); - results = vectorStore.similaritySearch(SearchRequest.query("The World") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression("country == 'BG' && year == 2020")); + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("country == 'BG' && year == 2020")); - assertThat(results).hasSize(1); - assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); - results = vectorStore.similaritySearch(SearchRequest.query("The World") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression("NOT(country == 'BG' && year == 2020)")); + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("NOT(country == 'BG' && year == 2020)")); - assertThat(results).hasSize(2); - assertThat(results.get(0).getId()).isIn(nlDocument.getId(), bgDocument2.getId()); - assertThat(results.get(1).getId()).isIn(nlDocument.getId(), bgDocument2.getId()); + assertThat(results).hasSize(2); + assertThat(results.get(0).getId()).isIn(nlDocument.getId(), bgDocument2.getId()); + assertThat(results.get(1).getId()).isIn(nlDocument.getId(), bgDocument2.getId()); - }); + }); } @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "COSINE", "L2", "IP" }) public void documentUpdate(String metricType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType).run(context -> { + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType) + .run(context -> { - VectorStore vectorStore = context.getBean(VectorStore.class); + VectorStore vectorStore = context.getBean(VectorStore.class); - resetCollection(vectorStore); + resetCollection(vectorStore); - Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", - Collections.singletonMap("meta1", "meta1")); + Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", + Collections.singletonMap("meta1", "meta1")); - vectorStore.add(List.of(document)); + vectorStore.add(List.of(document)); - List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); + List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); - assertThat(results).hasSize(1); - Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(document.getId()); - assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); - assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(document.getId()); + assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); + assertThat(resultDoc.getMetadata()).containsKey("meta1"); + assertThat(resultDoc.getMetadata()).containsKey("distance"); - Document sameIdDocument = new Document(document.getId(), - "The World is Big and Salvation Lurks Around the Corner", - Collections.singletonMap("meta2", "meta2")); + Document sameIdDocument = new Document(document.getId(), + "The World is Big and Salvation Lurks Around the Corner", + Collections.singletonMap("meta2", "meta2")); - vectorStore.add(List.of(sameIdDocument)); + vectorStore.add(List.of(sameIdDocument)); - results = vectorStore.similaritySearch(SearchRequest.query("FooBar").withTopK(5)); + results = vectorStore.similaritySearch(SearchRequest.query("FooBar").withTopK(5)); - assertThat(results).hasSize(1); - resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(document.getId()); - assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); - assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(results).hasSize(1); + resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(document.getId()); + assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); + assertThat(resultDoc.getMetadata()).containsKey("meta2"); + assertThat(resultDoc.getMetadata()).containsKey("distance"); - vectorStore.delete(List.of(document.getId())); + vectorStore.delete(List.of(document.getId())); - }); + }); } @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "COSINE", "IP" }) public void searchWithThreshold(String metricType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType).run(context -> { + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType) + .run(context -> { - VectorStore vectorStore = context.getBean(VectorStore.class); + VectorStore vectorStore = context.getBean(VectorStore.class); - resetCollection(vectorStore); + resetCollection(vectorStore); - vectorStore.add(documents); + vectorStore.add(this.documents); - List fullResult = vectorStore - .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); + List fullResult = vectorStore + .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List distances = fullResult.stream() + .map(doc -> (Float) doc.getMetadata().get("distance")) + .toList(); - assertThat(distances).hasSize(3); + assertThat(distances).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + float threshold = (distances.get(0) + distances.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore + .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); - assertThat(results).hasSize(1); - Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); - assertThat(resultDoc.getContent()).contains( - "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); + assertThat(resultDoc.getContent()).contains( + "Spring AI provides abstractions that serve as the foundation for developing AI applications."); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); - }); + }); } @SpringBootConfiguration @@ -265,7 +272,7 @@ public class MilvusVectorStoreIT { .withCollectionName("test_vector_store") .withDatabaseName("default") .withIndexType(IndexType.IVF_FLAT) - .withMetricType(metricType) + .withMetricType(this.metricType) .build(); return new MilvusVectorStore(milvusClient, embeddingModel, config, true, new TokenCountBatchingStrategy()); } @@ -288,4 +295,4 @@ public class MilvusVectorStoreIT { } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreObservationIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreObservationIT.java index 4abca7875..3a8a7bc7e 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; +import io.milvus.client.MilvusServiceClient; +import io.milvus.param.ConnectParam; +import io.milvus.param.IndexType; +import io.milvus.param.MetricType; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.milvus.MilvusContainer; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -40,17 +50,8 @@ import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.milvus.MilvusContainer; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import io.milvus.client.MilvusServiceClient; -import io.milvus.param.ConnectParam; -import io.milvus.param.IndexType; -import io.milvus.param.MetricType; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -85,13 +86,13 @@ public class MilvusVectorStoreObservationIT { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-mongodb-atlas-store/pom.xml b/vector-stores/spring-ai-mongodb-atlas-store/pom.xml index c9bde1048..48689d36d 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/pom.xml +++ b/vector-stores/spring-ai-mongodb-atlas-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterExpressionConverter.java b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterExpressionConverter.java index be2705c64..c72f51c4f 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterExpressionConverter.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.springframework.ai.vectorstore.filter.Filter; diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java index aa16e3f29..0cb9f974b 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,14 +16,15 @@ package org.springframework.ai.vectorstore; -import static org.springframework.data.mongodb.core.query.Criteria.where; - import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; +import com.mongodb.MongoCommandException; +import io.micrometer.observation.ObservationRegistry; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -42,9 +43,7 @@ import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.Query; import org.springframework.util.Assert; -import com.mongodb.MongoCommandException; - -import io.micrometer.observation.ObservationRegistry; +import static org.springframework.data.mongodb.core.query.Criteria.where; /** * @author Chris Smith @@ -119,8 +118,8 @@ public class MongoDBAtlasVectorStore extends AbstractObservationVectorStore impl } // Create the collection if it does not exist - if (!mongoTemplate.collectionExists(this.config.collectionName)) { - mongoTemplate.createCollection(this.config.collectionName); + if (!this.mongoTemplate.collectionExists(this.config.collectionName)) { + this.mongoTemplate.createCollection(this.config.collectionName); } // Create search index createSearchIndex(); @@ -128,7 +127,7 @@ public class MongoDBAtlasVectorStore extends AbstractObservationVectorStore impl private void createSearchIndex() { try { - mongoTemplate.executeCommand(createSearchIndexDefinition()); + this.mongoTemplate.executeCommand(createSearchIndexDefinition()); } catch (UncategorizedMongoDbException e) { Throwable cause = e.getCause(); @@ -228,6 +227,15 @@ public class MongoDBAtlasVectorStore extends AbstractObservationVectorStore impl .toList(); } + @Override + public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { + + return VectorStoreObservationContext.builder(VectorStoreProvider.MONGODB.value(), operationName) + .withCollectionName(this.config.collectionName) + .withDimensions(this.embeddingModel.dimensions()) + .withFieldName(this.config.pathName); + } + public static class MongoDBVectorStoreConfig { private final String collectionName; @@ -324,13 +332,4 @@ public class MongoDBAtlasVectorStore extends AbstractObservationVectorStore impl } - @Override - public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { - - return VectorStoreObservationContext.builder(VectorStoreProvider.MONGODB.value(), operationName) - .withCollectionName(this.config.collectionName) - .withDimensions(this.embeddingModel.dimensions()) - .withFieldName(this.config.pathName); - } - -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/VectorSearchAggregation.java b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/VectorSearchAggregation.java index 6888b9e11..b741193dd 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/VectorSearchAggregation.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/VectorSearchAggregation.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; import org.bson.Document; + import org.springframework.data.mongodb.core.aggregation.AggregationOperation; import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext; import org.springframework.lang.NonNull; @@ -28,15 +30,16 @@ record VectorSearchAggregation(List embeddings, String path, int numCandi @SuppressWarnings("null") @Override public org.bson.Document toDocument(@NonNull AggregationOperationContext context) { - var vectorSearch = new Document("queryVector", embeddings).append("path", path) - .append("numCandidates", numCandidates) - .append("index", index) - .append("limit", count); - if (!filter.isEmpty()) { - vectorSearch.append("filter", Document.parse(filter)); + var vectorSearch = new Document("queryVector", this.embeddings).append("path", this.path) + .append("numCandidates", this.numCandidates) + .append("index", this.index) + .append("limit", this.count); + if (!this.filter.isEmpty()) { + vectorSearch.append("filter", Document.parse(this.filter)); } var doc = new Document("$vectorSearch", vectorSearch); return context.getMappedObject(doc); } -} \ No newline at end of file + +} diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterConverterTest.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterConverterTest.java index 6ab38c55c..a8df6929e 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterConverterTest.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasFilterConverterTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; @@ -45,14 +46,14 @@ public class MongoDBAtlasFilterConverterTest { @Test public void testEQ() { // country == "BG" - String vectorExpr = converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); + String vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("{\"metadata.country\":{$eq:\"BG\"}}"); } @Test public void tesEqAndGte() { // genre == "drama" AND year >= 2020 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); assertThat(vectorExpr) @@ -62,7 +63,7 @@ public class MongoDBAtlasFilterConverterTest { @Test public void tesIn() { // genre in ["comedy", "documentary", "drama"] - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("{\"metadata.genre\":{$in:[\"comedy\",\"documentary\",\"drama\"]}}"); } @@ -70,7 +71,7 @@ public class MongoDBAtlasFilterConverterTest { @Test public void testNe() { // year >= 2020 OR country == "BG" AND city != "Sofia" - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), new Expression(NE, new Key("city"), new Value("Sofia"))))); @@ -81,7 +82,7 @@ public class MongoDBAtlasFilterConverterTest { @Test public void testGroup() { // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(EQ, new Key("country"), new Value("BG")))), new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv"))))); @@ -92,7 +93,7 @@ public class MongoDBAtlasFilterConverterTest { @Test public void testBoolean() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), new Expression(GTE, new Key("year"), new Value(2020))), new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); @@ -104,7 +105,7 @@ public class MongoDBAtlasFilterConverterTest { @Test public void testDecimal() { // temperature >= -15.6 && temperature <= +20.13 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), new Expression(LTE, new Key("temperature"), new Value(20.13)))); @@ -114,11 +115,11 @@ public class MongoDBAtlasFilterConverterTest { @Test public void testComplexIdentifiers() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); assertThat(vectorExpr).isEqualTo("{\"metadata.country 1 2 3\":{$eq:\"BG\"}}"); - vectorExpr = converter.convertExpression(new Expression(EQ, new Key("'country 1 2 3'"), new Value("BG"))); + vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("'country 1 2 3'"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("{\"metadata.country 1 2 3\":{$eq:\"BG\"}}"); } diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java index 45c8a140b..5ec855ec0 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; + import com.mongodb.client.MongoClient; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -34,17 +45,6 @@ import org.springframework.data.mongodb.core.convert.MongoCustomConversions; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.util.MimeType; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.UUID; -import java.util.stream.Collectors; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -66,7 +66,7 @@ class MongoDBAtlasVectorStoreIT { @BeforeEach public void beforeEach() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MongoTemplate mongoTemplate = context.getBean(MongoTemplate.class); mongoTemplate.getCollection("vector_store").deleteMany(new org.bson.Document()); }); @@ -74,7 +74,7 @@ class MongoDBAtlasVectorStoreIT { @Test void vectorStoreTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); List documents = List.of( @@ -109,7 +109,7 @@ class MongoDBAtlasVectorStoreIT { @Test void documentUpdateTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", @@ -144,7 +144,7 @@ class MongoDBAtlasVectorStoreIT { @Test void searchWithFilters() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", @@ -228,6 +228,7 @@ class MongoDBAtlasVectorStoreIT { @Bean public Converter mimeTypeToStringConverter() { return new Converter() { + @Override public String convert(MimeType source) { return source.toString(); @@ -238,6 +239,7 @@ class MongoDBAtlasVectorStoreIT { @Bean public Converter stringToMimeTypeConverter() { return new Converter() { + @Override public MimeType convert(String source) { return MimeType.valueOf(source); @@ -253,4 +255,4 @@ class MongoDBAtlasVectorStoreIT { } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbImage.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbImage.java index ed59c4993..946e813e4 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbImage.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbVectorStoreObservationIT.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbVectorStoreObservationIT.java index 3c78d31ea..b8ddad7b4 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -23,9 +22,17 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import com.mongodb.client.MongoClient; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -48,15 +55,7 @@ import org.springframework.data.mongodb.core.convert.MongoCustomConversions; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.util.MimeType; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import com.mongodb.client.MongoClient; - -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -93,7 +92,7 @@ public class MongoDbVectorStoreObservationIT { @BeforeEach public void beforeEach() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { MongoTemplate mongoTemplate = context.getBean(MongoTemplate.class); mongoTemplate.getCollection("vector_store").deleteMany(new org.bson.Document()); }); @@ -102,13 +101,13 @@ public class MongoDbVectorStoreObservationIT { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Thread.sleep(5000); @@ -212,6 +211,7 @@ public class MongoDbVectorStoreObservationIT { @Bean public Converter mimeTypeToStringConverter() { return new Converter() { + @Override public String convert(MimeType source) { return source.toString(); @@ -222,6 +222,7 @@ public class MongoDbVectorStoreObservationIT { @Bean public Converter stringToMimeTypeConverter() { return new Converter() { + @Override public MimeType convert(String source) { return MimeType.valueOf(source); diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/VectorSearchAggregationTest.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/VectorSearchAggregationTest.java index d217eb54b..3e6cce348 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/VectorSearchAggregationTest.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/VectorSearchAggregationTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.List; + import org.bson.Document; import org.junit.jupiter.api.Test; -import org.springframework.data.mongodb.core.aggregation.Aggregation; -import java.util.List; +import org.springframework.data.mongodb.core.aggregation.Aggregation; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -60,4 +62,4 @@ class VectorSearchAggregationTest { assertEquals(expected, document); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-neo4j-store/pom.xml b/vector-stores/spring-ai-neo4j-store/pom.xml index ae1b5b3af..913277180 100644 --- a/vector-stores/spring-ai-neo4j-store/pom.xml +++ b/vector-stores/spring-ai-neo4j-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java index 55c169d24..6d938a801 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -22,10 +22,12 @@ import java.util.Map; import java.util.Optional; import java.util.function.Predicate; +import io.micrometer.observation.ObservationRegistry; import org.neo4j.cypherdsl.support.schema_name.SchemaNames; import org.neo4j.driver.Driver; import org.neo4j.driver.SessionConfig; import org.neo4j.driver.Values; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -40,8 +42,6 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationConv import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; -import io.micrometer.observation.ObservationRegistry; - /** * @author Gerrit Meier * @author Michael Simons @@ -51,222 +51,6 @@ import io.micrometer.observation.ObservationRegistry; */ public class Neo4jVectorStore extends AbstractObservationVectorStore implements InitializingBean { - /** - * An enum to configure the distance function used in the Neo4j vector index. - */ - public enum Neo4jDistanceType { - - COSINE("cosine"), EUCLIDEAN("euclidean"); - - public final String name; - - Neo4jDistanceType(String name) { - this.name = name; - } - - } - - /** - * Configuration for the Neo4j vector store. - */ - public static final class Neo4jVectorStoreConfig { - - private final SessionConfig sessionConfig; - - private final int embeddingDimension; - - private final Neo4jDistanceType distanceType; - - private final String embeddingProperty; - - private final String label; - - private final String indexName; - - // needed for similarity search call - private final String indexNameNotSanitized; - - private final String idProperty; - - private final String constraintName; - - /** - * Start building a new configuration. - * @return The entry point for creating a new configuration. - */ - public static Builder builder() { - - return new Builder(); - } - - /** - * {@return the default config} - */ - public static Neo4jVectorStoreConfig defaultConfig() { - - return builder().build(); - } - - private Neo4jVectorStoreConfig(Builder builder) { - - this.sessionConfig = Optional.ofNullable(builder.databaseName) - .filter(Predicate.not(String::isBlank)) - .map(SessionConfig::forDatabase) - .orElseGet(SessionConfig::defaultConfig); - this.embeddingDimension = builder.embeddingDimension; - this.distanceType = builder.distanceType; - this.embeddingProperty = SchemaNames.sanitize(builder.embeddingProperty).orElseThrow(); - this.label = SchemaNames.sanitize(builder.label).orElseThrow(); - this.indexNameNotSanitized = builder.indexName; - this.indexName = SchemaNames.sanitize(builder.indexName, true).orElseThrow(); - this.constraintName = SchemaNames.sanitize(builder.constraintName).orElseThrow(); - this.idProperty = SchemaNames.sanitize(builder.idProperty).orElseThrow(); - } - - public static class Builder { - - private String databaseName; - - private int embeddingDimension = DEFAULT_EMBEDDING_DIMENSION; - - private Neo4jDistanceType distanceType = Neo4jDistanceType.COSINE; - - private String label = DEFAULT_LABEL; - - private String embeddingProperty = DEFAULT_EMBEDDING_PROPERTY; - - private String indexName = DEFAULT_INDEX_NAME; - - private String idProperty = DEFAULT_ID_PROPERTY; - - private String constraintName = DEFAULT_CONSTRAINT_NAME; - - private Builder() { - } - - /** - * Configures the Neo4j database name to use. Leave {@literal null} or blank - * to use the default database. - * @param databaseName the database name to use - * @return this builder - */ - public Builder withDatabaseName(String databaseName) { - this.databaseName = databaseName; - return this; - } - - /** - * Configures the size of the embedding. Defaults to {@literal 1536}, inline - * with OpenAIs embeddings. - * @param newEmbeddingDimension The dimension of the embedding - * @return this builder - */ - public Builder withEmbeddingDimension(int newEmbeddingDimension) { - - Assert.isTrue(newEmbeddingDimension >= 1, "Dimension has to be positive."); - - this.embeddingDimension = newEmbeddingDimension; - return this; - } - - /** - * Configures the distance type to store in the index and to use in queries. - * @param newDistanceType The distance type, must not be {@literal null} - * @return this builder - */ - public Builder withDistanceType(Neo4jDistanceType newDistanceType) { - - Assert.notNull(newDistanceType, "Distance type may not be null"); - - this.distanceType = newDistanceType; - return this; - } - - /** - * Configures the node label to use for storing documents. Defaults to - * {@literal Document}. - * @param newLabel The label used on the nodes representing the document - * @return this builder - */ - public Builder withLabel(String newLabel) { - - Assert.hasText(newLabel, "Content label may not be null or blank"); - - this.label = newLabel; - return this; - } - - /** - * Configures the property of the node to use for storing embedding. Defaults - * to {@literal embedding}. - * @param newEmbeddingProperty The property of the nodes for storing the - * embedding - * @return this builder - */ - public Builder withEmbeddingProperty(String newEmbeddingProperty) { - - Assert.hasText(newEmbeddingProperty, "Embedding property may not be null or blank"); - - this.embeddingProperty = newEmbeddingProperty; - return this; - } - - /** - * Configures the vector index to be used. Defaults to - * {@literal spring-ai-document-index}. - * @param newIndexName The name of the index to be used for storing and - * searching data. - * @return this builder - */ - public Builder withIndexName(String newIndexName) { - - Assert.hasText(newIndexName, "Index name may not be null or blank"); - - this.indexName = newIndexName; - return this; - } - - /** - * Configures the id property to be used. Defaults to {@literal id}. - * @param newIdProperty The name of the id property of the {@link Document} - * entity - * @return this builder - */ - public Builder withIdProperty(String newIdProperty) { - - Assert.hasText(newIdProperty, "Id property may not be null or blank"); - - this.idProperty = newIdProperty; - return this; - } - - /** - * Configures the constraint name to be used. Defaults to - * {@literal Document_unique_idx}. - * @param newConstraintName The name of the unique constraint for the id - * property. - * @return this builder - */ - public Builder withConstraintName(String newConstraintName) { - - Assert.hasText(newConstraintName, "Constraint name may not be null or blank"); - - this.constraintName = newConstraintName; - return this; - } - - /** - * {@return the immutable configuration} - */ - public Neo4jVectorStoreConfig build() { - - return new Neo4jVectorStoreConfig(this); - } - - } - - } - public static final int DEFAULT_EMBEDDING_DIMENSION = 1536; public static final String DEFAULT_LABEL = "Document"; @@ -279,6 +63,10 @@ public class Neo4jVectorStore extends AbstractObservationVectorStore implements public static final String DEFAULT_CONSTRAINT_NAME = DEFAULT_LABEL + "_unique_idx"; + private static Map SIMILARITY_TYPE_MAPPING = Map.of( + Neo4jDistanceType.COSINE, VectorStoreSimilarityMetric.COSINE, Neo4jDistanceType.EUCLIDEAN, + VectorStoreSimilarityMetric.EUCLIDEAN); + private final Neo4jVectorFilterExpressionConverter filterExpressionConverter = new Neo4jVectorFilterExpressionConverter(); private final Driver driver; @@ -445,10 +233,6 @@ public class Neo4jVectorStore extends AbstractObservationVectorStore implements .withSimilarityMetric(getSimilarityMetric()); } - private static Map SIMILARITY_TYPE_MAPPING = Map.of( - Neo4jDistanceType.COSINE, VectorStoreSimilarityMetric.COSINE, Neo4jDistanceType.EUCLIDEAN, - VectorStoreSimilarityMetric.EUCLIDEAN); - private String getSimilarityMetric() { if (!SIMILARITY_TYPE_MAPPING.containsKey(this.config.distanceType)) { return this.config.distanceType.name(); @@ -456,4 +240,220 @@ public class Neo4jVectorStore extends AbstractObservationVectorStore implements return SIMILARITY_TYPE_MAPPING.get(this.config.distanceType).value(); } -} \ No newline at end of file + /** + * An enum to configure the distance function used in the Neo4j vector index. + */ + public enum Neo4jDistanceType { + + COSINE("cosine"), EUCLIDEAN("euclidean"); + + public final String name; + + Neo4jDistanceType(String name) { + this.name = name; + } + + } + + /** + * Configuration for the Neo4j vector store. + */ + public static final class Neo4jVectorStoreConfig { + + private final SessionConfig sessionConfig; + + private final int embeddingDimension; + + private final Neo4jDistanceType distanceType; + + private final String embeddingProperty; + + private final String label; + + private final String indexName; + + // needed for similarity search call + private final String indexNameNotSanitized; + + private final String idProperty; + + private final String constraintName; + + private Neo4jVectorStoreConfig(Builder builder) { + + this.sessionConfig = Optional.ofNullable(builder.databaseName) + .filter(Predicate.not(String::isBlank)) + .map(SessionConfig::forDatabase) + .orElseGet(SessionConfig::defaultConfig); + this.embeddingDimension = builder.embeddingDimension; + this.distanceType = builder.distanceType; + this.embeddingProperty = SchemaNames.sanitize(builder.embeddingProperty).orElseThrow(); + this.label = SchemaNames.sanitize(builder.label).orElseThrow(); + this.indexNameNotSanitized = builder.indexName; + this.indexName = SchemaNames.sanitize(builder.indexName, true).orElseThrow(); + this.constraintName = SchemaNames.sanitize(builder.constraintName).orElseThrow(); + this.idProperty = SchemaNames.sanitize(builder.idProperty).orElseThrow(); + } + + /** + * Start building a new configuration. + * @return The entry point for creating a new configuration. + */ + public static Builder builder() { + + return new Builder(); + } + + /** + * {@return the default config} + */ + public static Neo4jVectorStoreConfig defaultConfig() { + + return builder().build(); + } + + public static class Builder { + + private String databaseName; + + private int embeddingDimension = DEFAULT_EMBEDDING_DIMENSION; + + private Neo4jDistanceType distanceType = Neo4jDistanceType.COSINE; + + private String label = DEFAULT_LABEL; + + private String embeddingProperty = DEFAULT_EMBEDDING_PROPERTY; + + private String indexName = DEFAULT_INDEX_NAME; + + private String idProperty = DEFAULT_ID_PROPERTY; + + private String constraintName = DEFAULT_CONSTRAINT_NAME; + + private Builder() { + } + + /** + * Configures the Neo4j database name to use. Leave {@literal null} or blank + * to use the default database. + * @param databaseName the database name to use + * @return this builder + */ + public Builder withDatabaseName(String databaseName) { + this.databaseName = databaseName; + return this; + } + + /** + * Configures the size of the embedding. Defaults to {@literal 1536}, inline + * with OpenAIs embeddings. + * @param newEmbeddingDimension The dimension of the embedding + * @return this builder + */ + public Builder withEmbeddingDimension(int newEmbeddingDimension) { + + Assert.isTrue(newEmbeddingDimension >= 1, "Dimension has to be positive."); + + this.embeddingDimension = newEmbeddingDimension; + return this; + } + + /** + * Configures the distance type to store in the index and to use in queries. + * @param newDistanceType The distance type, must not be {@literal null} + * @return this builder + */ + public Builder withDistanceType(Neo4jDistanceType newDistanceType) { + + Assert.notNull(newDistanceType, "Distance type may not be null"); + + this.distanceType = newDistanceType; + return this; + } + + /** + * Configures the node label to use for storing documents. Defaults to + * {@literal Document}. + * @param newLabel The label used on the nodes representing the document + * @return this builder + */ + public Builder withLabel(String newLabel) { + + Assert.hasText(newLabel, "Content label may not be null or blank"); + + this.label = newLabel; + return this; + } + + /** + * Configures the property of the node to use for storing embedding. Defaults + * to {@literal embedding}. + * @param newEmbeddingProperty The property of the nodes for storing the + * embedding + * @return this builder + */ + public Builder withEmbeddingProperty(String newEmbeddingProperty) { + + Assert.hasText(newEmbeddingProperty, "Embedding property may not be null or blank"); + + this.embeddingProperty = newEmbeddingProperty; + return this; + } + + /** + * Configures the vector index to be used. Defaults to + * {@literal spring-ai-document-index}. + * @param newIndexName The name of the index to be used for storing and + * searching data. + * @return this builder + */ + public Builder withIndexName(String newIndexName) { + + Assert.hasText(newIndexName, "Index name may not be null or blank"); + + this.indexName = newIndexName; + return this; + } + + /** + * Configures the id property to be used. Defaults to {@literal id}. + * @param newIdProperty The name of the id property of the {@link Document} + * entity + * @return this builder + */ + public Builder withIdProperty(String newIdProperty) { + + Assert.hasText(newIdProperty, "Id property may not be null or blank"); + + this.idProperty = newIdProperty; + return this; + } + + /** + * Configures the constraint name to be used. Defaults to + * {@literal Document_unique_idx}. + * @param newConstraintName The name of the unique constraint for the id + * property. + * @return this builder + */ + public Builder withConstraintName(String newConstraintName) { + + Assert.hasText(newConstraintName, "Constraint name may not be null or blank"); + + this.constraintName = newConstraintName; + return this; + } + + /** + * {@return the immutable configuration} + */ + public Neo4jVectorStoreConfig build() { + + return new Neo4jVectorStoreConfig(this); + } + + } + + } + +} diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverter.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverter.java index 41ce4f800..732169999 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverter.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import org.springframework.ai.vectorstore.filter.Filter.Expression; diff --git a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jImage.java b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jImage.java index 3ea60478d..513fd6943 100644 --- a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jImage.java +++ b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java index 57fcc7179..3aa438239 100644 --- a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java +++ b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.Collections; @@ -27,15 +28,15 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.neo4j.driver.AuthTokens; import org.neo4j.driver.Driver; import org.neo4j.driver.GraphDatabase; -import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser; import org.testcontainers.containers.Neo4jContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.OpenAiEmbeddingModel; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; @@ -57,6 +58,9 @@ class Neo4jVectorStoreIT { @Container static Neo4jContainer neo4jContainer = new Neo4jContainer<>(Neo4jImage.DEFAULT_IMAGE).withRandomPassword(); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + List documents = List.of( new Document("Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!!", Collections.singletonMap("meta1", "meta1")), @@ -65,9 +69,6 @@ class Neo4jVectorStoreIT { "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression", Collections.singletonMap("meta2", "meta2"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(TestApplication.class); - @BeforeEach void cleanDatabase() { this.contextRunner diff --git a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreObservationIT.java b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreObservationIT.java index ebf09454d..1c841b1c1 100644 --- a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; @@ -29,6 +31,10 @@ import org.neo4j.cypherdsl.support.schema_name.SchemaNames; import org.neo4j.driver.AuthTokens; import org.neo4j.driver.Driver; import org.neo4j.driver.GraphDatabase; +import org.testcontainers.containers.Neo4jContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -46,13 +52,8 @@ import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.containers.Neo4jContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -92,13 +93,13 @@ public class Neo4jVectorStoreObservationIT { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverterTests.java b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverterTests.java index 8c8a397f8..4eeaa22f1 100644 --- a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.filter; import java.util.List; @@ -46,14 +47,14 @@ public class Neo4jVectorFilterExpressionConverterTests { @Test public void testEQ() { // country = "BG" - String vectorExpr = converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); + String vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("node.`metadata.country` = \"BG\""); } @Test public void tesEqAndGte() { // genre = "drama" AND year >= 2020 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); assertThat(vectorExpr).isEqualTo("node.`metadata.genre` = \"drama\" AND node.`metadata.year` >= 2020"); @@ -62,7 +63,7 @@ public class Neo4jVectorFilterExpressionConverterTests { @Test public void tesIn() { // genre in ["comedy", "documentary", "drama"] - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("node.`metadata.genre` IN [\"comedy\",\"documentary\",\"drama\"]"); } @@ -70,7 +71,7 @@ public class Neo4jVectorFilterExpressionConverterTests { @Test public void tesNIn() { // genre in ["comedy", "documentary", "drama"] - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Expression(NIN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("NOT node.`metadata.genre` IN [\"comedy\",\"documentary\",\"drama\"]"); } @@ -78,7 +79,7 @@ public class Neo4jVectorFilterExpressionConverterTests { @Test public void testNe() { // year >= 2020 OR country = "BG" AND city <> "Sofia" - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), new Expression(NE, new Key("city"), new Value("Sofia"))))); @@ -89,7 +90,7 @@ public class Neo4jVectorFilterExpressionConverterTests { @Test public void testGroup() { // (year >= 2020 OR country = "BG") AND NOT city IN ["Sofia", "Plovdiv"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(EQ, new Key("country"), new Value("BG")))), new Expression(NOT, new Expression(IN, new Key("city"), new Value(List.of("Sofia", "Plovdiv")))))); @@ -100,7 +101,7 @@ public class Neo4jVectorFilterExpressionConverterTests { @Test public void testBoolean() { // isOpen = true AND year >= 2020 AND country IN ["BG", "NL", "US"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), new Expression(GTE, new Key("year"), new Value(2020))), new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); @@ -112,7 +113,7 @@ public class Neo4jVectorFilterExpressionConverterTests { @Test public void testDecimal() { // temperature >= -15.6 AND temperature <= +20.13 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), new Expression(LTE, new Key("temperature"), new Value(20.13)))); @@ -122,7 +123,7 @@ public class Neo4jVectorFilterExpressionConverterTests { @Test public void testComplexIdentifiers() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); assertThat(vectorExpr).isEqualTo("node.`metadata.country 1 2 3` = \"BG\""); } @@ -131,7 +132,7 @@ public class Neo4jVectorFilterExpressionConverterTests { public void testComplexIdentifiers2() { Filter.Expression expr = new FilterExpressionTextParser() .parse("author in ['john', 'jill'] && 'article_type' == 'blog'"); - String vectorExpr = converter.convertExpression(expr); + String vectorExpr = this.converter.convertExpression(expr); assertThat(vectorExpr) .isEqualTo("node.`metadata.author` IN [\"john\",\"jill\"] AND node.`metadata.'article_type'` = \"blog\""); } diff --git a/vector-stores/spring-ai-opensearch-store/pom.xml b/vector-stores/spring-ai-opensearch-store/pom.xml index 33deb0fdc..aa3ddfbbd 100644 --- a/vector-stores/spring-ai-opensearch-store/pom.xml +++ b/vector-stores/spring-ai-opensearch-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverter.java b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverter.java index 9035a86d2..98876f5e6 100644 --- a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverter.java +++ b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import org.springframework.ai.vectorstore.filter.Filter; -import org.springframework.ai.vectorstore.filter.Filter.Expression; -import org.springframework.ai.vectorstore.filter.Filter.Key; -import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; +package org.springframework.ai.vectorstore; import java.text.ParseException; import java.text.SimpleDateFormat; @@ -27,6 +23,11 @@ import java.util.List; import java.util.TimeZone; import java.util.regex.Pattern; +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; + /** * @author Jemin Huh * @since 1.0.0 diff --git a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java index 1756ded2f..f81b108cc 100644 --- a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java +++ b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,14 @@ package org.springframework.ai.vectorstore; +import java.io.IOException; +import java.io.StringReader; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import io.micrometer.observation.ObservationRegistry; import org.opensearch.client.json.JsonData; import org.opensearch.client.json.JsonpMapper; import org.opensearch.client.opensearch.OpenSearchClient; @@ -30,6 +38,7 @@ import org.opensearch.client.opensearch.indices.CreateIndexResponse; import org.opensearch.client.transport.endpoints.BooleanResponse; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -46,15 +55,6 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationConv import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; -import io.micrometer.observation.ObservationRegistry; - -import java.io.IOException; -import java.io.StringReader; -import java.util.List; -import java.util.Objects; -import java.util.Optional; -import java.util.stream.Collectors; - /** * @author Jemin Huh * @author Soby Chacko @@ -67,8 +67,6 @@ public class OpenSearchVectorStore extends AbstractObservationVectorStore implem public static final String COSINE_SIMILARITY_FUNCTION = "cosinesimil"; - private static final Logger logger = LoggerFactory.getLogger(OpenSearchVectorStore.class); - public static final String DEFAULT_INDEX_NAME = "spring-ai-document-index"; public static final String DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536 = """ @@ -82,6 +80,8 @@ public class OpenSearchVectorStore extends AbstractObservationVectorStore implem } """; + private static final Logger logger = LoggerFactory.getLogger(OpenSearchVectorStore.class); + private final EmbeddingModel embeddingModel; private final OpenSearchClient openSearchClient; @@ -92,12 +92,12 @@ public class OpenSearchVectorStore extends AbstractObservationVectorStore implem private final String mappingJson; - private String similarityFunction; - private final boolean initializeSchema; private final BatchingStrategy batchingStrategy; + private String similarityFunction; + public OpenSearchVectorStore(OpenSearchClient openSearchClient, EmbeddingModel embeddingModel, boolean initializeSchema) { this(openSearchClient, embeddingModel, DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536, @@ -245,7 +245,7 @@ public class OpenSearchVectorStore extends AbstractObservationVectorStore implem } private CreateIndexResponse createIndexMapping(String index, String mappingJson) { - JsonpMapper jsonpMapper = openSearchClient._transport().jsonpMapper(); + JsonpMapper jsonpMapper = this.openSearchClient._transport().jsonpMapper(); try { return this.openSearchClient.indices() .create(new CreateIndexRequest.Builder().index(index) @@ -285,4 +285,4 @@ public class OpenSearchVectorStore extends AbstractObservationVectorStore implem return this.similarityFunction; } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverterTest.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverterTest.java index e830fa545..77e2a95a0 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverterTest.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverterTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.Date; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; + import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AND; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ; @@ -25,38 +34,31 @@ import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NE import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; -import java.util.Date; -import java.util.List; - -import org.junit.jupiter.api.Test; -import org.springframework.ai.vectorstore.filter.Filter; -import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; - class OpenSearchAiSearchFilterExpressionConverterTest { final FilterExpressionConverter converter = new OpenSearchAiSearchFilterExpressionConverter(); @Test public void testDate() { - String vectorExpr = converter.convertExpression(new Filter.Expression(EQ, new Filter.Key("activationDate"), + String vectorExpr = this.converter.convertExpression(new Filter.Expression(EQ, new Filter.Key("activationDate"), new Filter.Value(new Date(1704637752148L)))); assertThat(vectorExpr).isEqualTo("metadata.activationDate:2024-01-07T14:29:12Z"); - vectorExpr = converter.convertExpression( + vectorExpr = this.converter.convertExpression( new Filter.Expression(EQ, new Filter.Key("activationDate"), new Filter.Value("1970-01-01T00:00:02Z"))); assertThat(vectorExpr).isEqualTo("metadata.activationDate:1970-01-01T00:00:02Z"); } @Test public void testEQ() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Filter.Expression(EQ, new Filter.Key("country"), new Filter.Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata.country:BG"); } @Test public void tesEqAndGte() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Expression(EQ, new Filter.Key("genre"), new Filter.Value("drama")), new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020)))); assertThat(vectorExpr).isEqualTo("metadata.genre:drama AND metadata.year:>=2020"); @@ -64,14 +66,14 @@ class OpenSearchAiSearchFilterExpressionConverterTest { @Test public void tesIn() { - String vectorExpr = converter.convertExpression(new Filter.Expression(IN, new Filter.Key("genre"), + String vectorExpr = this.converter.convertExpression(new Filter.Expression(IN, new Filter.Key("genre"), new Filter.Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("(metadata.genre:comedy OR documentary OR drama)"); } @Test public void testNe() { - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Filter.Expression(OR, new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020)), new Filter.Expression(AND, new Filter.Expression(EQ, new Filter.Key("country"), new Filter.Value("BG")), @@ -81,7 +83,7 @@ class OpenSearchAiSearchFilterExpressionConverterTest { @Test public void testGroup() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Group(new Filter.Expression(OR, new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020)), new Filter.Expression(EQ, new Filter.Key("country"), new Filter.Value("BG")))), @@ -92,7 +94,7 @@ class OpenSearchAiSearchFilterExpressionConverterTest { @Test public void tesBoolean() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Expression(AND, new Filter.Expression(EQ, new Filter.Key("isOpen"), new Filter.Value(true)), new Filter.Expression(GTE, new Filter.Key("year"), new Filter.Value(2020))), new Filter.Expression(IN, new Filter.Key("country"), new Filter.Value(List.of("BG", "NL", "US"))))); @@ -103,7 +105,7 @@ class OpenSearchAiSearchFilterExpressionConverterTest { @Test public void testDecimal() { - String vectorExpr = converter.convertExpression(new Filter.Expression(AND, + String vectorExpr = this.converter.convertExpression(new Filter.Expression(AND, new Filter.Expression(GTE, new Filter.Key("temperature"), new Filter.Value(-15.6)), new Filter.Expression(LTE, new Filter.Key("temperature"), new Filter.Value(20.13)))); @@ -112,11 +114,11 @@ class OpenSearchAiSearchFilterExpressionConverterTest { @Test public void testComplexIdentifiers() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Filter.Expression(EQ, new Filter.Key("\"country 1 2 3\""), new Filter.Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata.country 1 2 3:BG"); - vectorExpr = converter + vectorExpr = this.converter .convertExpression(new Filter.Expression(EQ, new Filter.Key("'country 1 2 3'"), new Filter.Value("BG"))); assertThat(vectorExpr).isEqualTo("metadata.country 1 2 3:BG"); } diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchImage.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchImage.java index dea664624..294ed87d5 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchImage.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java index 645bebd4a..a38207fb4 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,17 @@ package org.springframework.ai.vectorstore; +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.ZonedDateTime; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + import org.apache.hc.core5.http.HttpHost; import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeAll; @@ -27,6 +38,9 @@ import org.junit.jupiter.params.provider.ValueSource; import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder; import org.opensearch.testcontainers.OpensearchContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.openai.OpenAiEmbeddingModel; @@ -38,19 +52,6 @@ import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.io.IOException; -import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.time.ZonedDateTime; -import java.util.Date; -import java.util.List; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.TimeUnit; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.Matchers.equalTo; @@ -126,7 +127,7 @@ class OpenSearchVectorStoreIT { vectorStore.withSimilarityFunction(similarityFunction); } - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await() .until(() -> vectorStore @@ -138,14 +139,14 @@ class OpenSearchVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore @@ -245,7 +246,7 @@ class OpenSearchVectorStoreIT { assertThat(results.get(0).getId()).isEqualTo(bgDocument2.getId()); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(1)), hasSize(0)); @@ -318,7 +319,7 @@ class OpenSearchVectorStoreIT { vectorStore.withSimilarityFunction(similarityFunction); } - vectorStore.add(documents); + vectorStore.add(this.documents); SearchRequest query = SearchRequest.query("Great Depression") .withTopK(50) @@ -339,13 +340,13 @@ class OpenSearchVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreObservationIT.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreObservationIT.java index 45298605d..7ce5101a1 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.net.URISyntaxException; @@ -25,6 +24,9 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.apache.hc.core5.http.HttpHost; import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeAll; @@ -34,6 +36,9 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder; import org.opensearch.testcontainers.OpensearchContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -51,13 +56,8 @@ import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.Matchers.hasSize; /** @@ -87,10 +87,6 @@ public class OpenSearchVectorStoreObservationIT { } } - private ApplicationContextRunner getContextRunner() { - return new ApplicationContextRunner().withUserConfiguration(Config.class); - } - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -98,6 +94,10 @@ public class OpenSearchVectorStoreObservationIT { Awaitility.setDefaultTimeout(Duration.ofMinutes(1)); } + private ApplicationContextRunner getContextRunner() { + return new ApplicationContextRunner().withUserConfiguration(Config.class); + } + @BeforeEach void cleanDatabase() { getContextRunner().run(context -> { @@ -115,7 +115,7 @@ public class OpenSearchVectorStoreObservationIT { TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() @@ -182,7 +182,7 @@ public class OpenSearchVectorStoreObservationIT { observationRegistry.clear(); - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); Awaitility.await() .until(() -> vectorStore diff --git a/vector-stores/spring-ai-oracle-store/pom.xml b/vector-stores/spring-ai-oracle-store/pom.xml index b95d5ee01..7335b4d88 100644 --- a/vector-stores/spring-ai-oracle-store/pom.xml +++ b/vector-stores/spring-ai-oracle-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java index 290f570f9..a32a904eb 100644 --- a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java +++ b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,9 +16,6 @@ package org.springframework.ai.vectorstore; -import static org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType.DOT; -import static org.springframework.jdbc.core.StatementCreatorUtils.setParameterValue; - import java.io.ByteArrayOutputStream; import java.sql.PreparedStatement; import java.sql.ResultSet; @@ -31,8 +28,16 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import io.micrometer.observation.ObservationRegistry; +import oracle.jdbc.OracleType; +import oracle.sql.VECTOR; +import oracle.sql.json.OracleJsonFactory; +import oracle.sql.json.OracleJsonGenerator; +import oracle.sql.json.OracleJsonObject; +import oracle.sql.json.OracleJsonValue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -51,13 +56,8 @@ import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.RowMapper; import org.springframework.util.StringUtils; -import io.micrometer.observation.ObservationRegistry; -import oracle.jdbc.OracleType; -import oracle.sql.VECTOR; -import oracle.sql.json.OracleJsonFactory; -import oracle.sql.json.OracleJsonGenerator; -import oracle.sql.json.OracleJsonObject; -import oracle.sql.json.OracleJsonValue; +import static org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType.DOT; +import static org.springframework.jdbc.core.StatementCreatorUtils.setParameterValue; /** *

    @@ -86,9 +86,461 @@ import oracle.sql.json.OracleJsonValue; */ public class OracleVectorStore extends AbstractObservationVectorStore implements InitializingBean { + public static final double SIMILARITY_THRESHOLD_EXACT_MATCH = 1.0d; + + public static final String DEFAULT_TABLE_NAME = "SPRING_AI_VECTORS"; + + public static final OracleVectorStoreIndexType DEFAULT_INDEX_TYPE = OracleVectorStoreIndexType.IVF; + + public static final OracleVectorStoreDistanceType DEFAULT_DISTANCE_TYPE = OracleVectorStoreDistanceType.COSINE; + + public static final int DEFAULT_DIMENSIONS = -1; + + public static final int DEFAULT_SEARCH_ACCURACY = -1; + private static final Logger logger = LoggerFactory.getLogger(OracleVectorStore.class); - public static final double SIMILARITY_THRESHOLD_EXACT_MATCH = 1.0d; + private static Map SIMILARITY_TYPE_MAPPING = Map.of( + OracleVectorStoreDistanceType.COSINE, VectorStoreSimilarityMetric.COSINE, + OracleVectorStoreDistanceType.EUCLIDEAN, VectorStoreSimilarityMetric.EUCLIDEAN, + OracleVectorStoreDistanceType.DOT, VectorStoreSimilarityMetric.DOT); + + public final FilterExpressionConverter filterExpressionConverter = new SqlJsonPathFilterExpressionConverter(); + + private final JdbcTemplate jdbcTemplate; + + private final EmbeddingModel embeddingModel; + + private final boolean initializeSchema; + + private final boolean removeExistingVectorStoreTable; + + /** + * Table name where vectors will be stored. + */ + private final String tableName; + + /** + * Index type used to index the vectors. It can impact performance and database memory + * consumption. + */ + private final OracleVectorStoreIndexType indexType; + + /** + * Distance type to use for computing vector distances. + */ + private final OracleVectorStoreDistanceType distanceType; + + /** + * Expected number of dimensions for vectors. Enforcing vector dimensions is very + * useful to ensure future vector distance computations will be relevant. + */ + private final int dimensions; + + private final boolean forcedNormalization; + + private final int searchAccuracy; + + private final BatchingStrategy batchingStrategy; + + private final OracleJsonFactory osonFactory = new OracleJsonFactory(); + + private final ByteArrayOutputStream out = new ByteArrayOutputStream(); + + public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { + this(jdbcTemplate, embeddingModel, DEFAULT_TABLE_NAME, DEFAULT_INDEX_TYPE, DEFAULT_DISTANCE_TYPE, + DEFAULT_DIMENSIONS, DEFAULT_SEARCH_ACCURACY, false, false, false); + } + + public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, boolean initializeSchema) { + this(jdbcTemplate, embeddingModel, DEFAULT_TABLE_NAME, DEFAULT_INDEX_TYPE, DEFAULT_DISTANCE_TYPE, + DEFAULT_DIMENSIONS, DEFAULT_SEARCH_ACCURACY, initializeSchema, false, false); + } + + public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, String tableName, + OracleVectorStoreIndexType indexType, OracleVectorStoreDistanceType distanceType, int dimensions, + int searchAccuracy, boolean initializeSchema, boolean removeExistingVectorStoreTable, + boolean forcedNormalization) { + this(jdbcTemplate, embeddingModel, tableName, indexType, distanceType, dimensions, searchAccuracy, + initializeSchema, removeExistingVectorStoreTable, forcedNormalization, ObservationRegistry.NOOP, null, + new TokenCountBatchingStrategy()); + } + + public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, String tableName, + OracleVectorStoreIndexType indexType, OracleVectorStoreDistanceType distanceType, int dimensions, + int searchAccuracy, boolean initializeSchema, boolean removeExistingVectorStoreTable, + boolean forcedNormalization, ObservationRegistry observationRegistry, + VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { + + super(observationRegistry, customObservationConvention); + + if (dimensions != DEFAULT_DIMENSIONS) { + if (dimensions <= 0) { + throw new RuntimeException("Number of dimensions must be strictly positive"); + } + if (dimensions > 65535) { + throw new RuntimeException("Number of dimensions must be at most 65535"); + } + } + + if (searchAccuracy != DEFAULT_SEARCH_ACCURACY) { + if (searchAccuracy < 1) { + throw new RuntimeException("Search accuracy must be greater or equals to 1"); + } + if (searchAccuracy > 100) { + throw new RuntimeException("Search accuracy must be lower or equals to 100"); + } + } + + this.jdbcTemplate = jdbcTemplate; + this.embeddingModel = embeddingModel; + this.tableName = tableName; + this.indexType = indexType; + this.distanceType = distanceType; + this.dimensions = dimensions; + this.searchAccuracy = searchAccuracy; + this.initializeSchema = initializeSchema; + this.removeExistingVectorStoreTable = removeExistingVectorStoreTable; + this.forcedNormalization = forcedNormalization; + this.batchingStrategy = batchingStrategy; + } + + @Override + public void doAdd(final List documents) { + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + this.jdbcTemplate.batchUpdate(getIngestStatement(), new BatchPreparedStatementSetter() { + + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + final Document document = documents.get(i); + final String content = document.getContent(); + final byte[] json = toJson(document.getMetadata()); + final VECTOR embeddingVector = toVECTOR(document.getEmbedding()); + + setParameterValue(ps, 1, Types.VARCHAR, document.getId()); + setParameterValue(ps, 2, Types.VARCHAR, content); + setParameterValue(ps, 3, OracleType.JSON.getVendorTypeNumber(), json); + setParameterValue(ps, 4, OracleType.VECTOR.getVendorTypeNumber(), embeddingVector); + } + + @Override + public int getBatchSize() { + return documents.size(); + } + }); + } + + private String getIngestStatement() { + return String + .format(""" + merge into %s target using (values(?, ?, ?, ?)) source (id, content, metadata, embedding) on (target.id = source.id) + when matched then update set target.content = source.content, target.metadata = source.metadata, target.embedding = source.embedding + when not matched then insert (target.id, target.content, target.metadata, target.embedding) values (source.id, source.content, source.metadata, source.embedding)""", + this.tableName); + } + + /** + * Bind binary JSON from the client. + * @param m map of metadata + * @return the binary JSON ready to be inserted + */ + private byte[] toJson(final Map m) { + this.out.reset(); + try (OracleJsonGenerator gen = this.osonFactory.createJsonBinaryGenerator(this.out)) { + gen.writeStartObject(); + for (String key : m.keySet()) { + final Object o = m.get(key); + if (o instanceof String) { + gen.write(key, (String) o); + } + else if (o instanceof Integer) { + gen.write(key, (Integer) o); + } + else if (o instanceof Float) { + gen.write(key, (Float) o); + } + else if (o instanceof Double) { + gen.write(key, (Double) o); + } + else if (o instanceof Boolean) { + gen.write(key, (Boolean) o); + } + } + gen.writeEnd(); + } + + return this.out.toByteArray(); + } + + /** + * Converts a list of Double values into an Oracle VECTOR object ready to be inserted. + * Optionally normalize the vector beforehand (see forcedNormalization). + * @param floatList + * @return + * @throws SQLException + */ + private VECTOR toVECTOR(final float[] floatList) throws SQLException { + final double[] doubles = new double[floatList.length]; + int i = 0; + for (double d : floatList) { + doubles[i++] = d; + } + + if (this.forcedNormalization) { + return VECTOR.ofFloat64Values(normalize(doubles)); + } + + return VECTOR.ofFloat64Values(doubles); + } + + /** + * Normalize a vector if requested. + * @param v vector to normalize + * @return the vector normalized + */ + private double[] normalize(final double[] v) { + double squaredSum = 0d; + + for (double e : v) { + squaredSum += e * e; + } + + final double magnitude = Math.sqrt(squaredSum); + + if (magnitude > 0) { + final double multiplier = 1d / magnitude; + final int length = v.length; + for (int i = 0; i < length; i++) { + v[i] *= multiplier; + } + } + + return v; + } + + @Override + public Optional doDelete(final List idList) { + final String sql = String.format("delete from %s where id=?", this.tableName); + final int[] argTypes = { Types.VARCHAR }; + + final List batchArgs = new ArrayList<>(); + for (String id : idList) { + batchArgs.add(new Object[] { id }); + } + + final int[] deleteCounts = this.jdbcTemplate.batchUpdate(sql, batchArgs, argTypes); + + int deleteCount = 0; + for (int detailedResult : deleteCounts) { + switch (detailedResult) { + case Statement.EXECUTE_FAILED: + break; + case 1: + case Statement.SUCCESS_NO_INFO: + deleteCount++; + break; + } + } + + return Optional.of(deleteCount == idList.size()); + } + + @Override + public List doSimilaritySearch(SearchRequest request) { + try { + // From the provided query, generate a vector using the embedding model + final VECTOR embeddingVector = toVECTOR(this.embeddingModel.embed(request.getQuery())); + + if (logger.isDebugEnabled()) { + this.jdbcTemplate.batchUpdate("insert into debug(embedding) values(?)", + new BatchPreparedStatementSetter() { + + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + setParameterValue(ps, 1, OracleType.VECTOR.getVendorTypeNumber(), embeddingVector); + } + + @Override + public int getBatchSize() { + return 1; + } + }); + } + + final String nativeFilterExpression = (request.getFilterExpression() != null) + ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : ""; + + String jsonPathFilter = ""; + + if (request.getSimilarityThreshold() == SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) { + if (StringUtils.hasText(nativeFilterExpression)) { + jsonPathFilter = String.format("where JSON_EXISTS( metadata, '%s' )\n", nativeFilterExpression); + } + + final String sql = this.searchAccuracy == DEFAULT_SEARCH_ACCURACY ? String.format(""" + select id, content, metadata, embedding, %sVECTOR_DISTANCE(embedding, ?, %s)%s as distance + from %s + %sorder by distance + fetch first %d rows only""", this.distanceType == DOT ? "(1+" : "", this.distanceType.name(), + this.distanceType == DOT ? ")/2" : "", this.tableName, jsonPathFilter, request.getTopK()) + : String.format( + """ + select id, content, metadata, embedding, %sVECTOR_DISTANCE(embedding, ?, %s)%s as distance + from %s + %sorder by distance + fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", + this.distanceType == DOT ? "(1+" : "", this.distanceType.name(), + this.distanceType == DOT ? ")/2" : "", this.tableName, jsonPathFilter, + request.getTopK(), this.searchAccuracy); + + logger.debug("SQL query: " + sql); + + return this.jdbcTemplate.query(sql, new DocumentRowMapper(), embeddingVector); + } + else if (request.getSimilarityThreshold() == SIMILARITY_THRESHOLD_EXACT_MATCH) { + if (StringUtils.hasText(nativeFilterExpression)) { + jsonPathFilter = String.format("where JSON_EXISTS( metadata, '%s' )\n", nativeFilterExpression); + } + + final String sql = String.format(""" + select id, content, metadata, embedding, %sVECTOR_DISTANCE(embedding, ?, %s)%s as distance + from %s + %sorder by distance + fetch EXACT first %d rows only""", this.distanceType == DOT ? "(1+" : "", + this.distanceType.name(), this.distanceType == DOT ? ")/2" : "", this.tableName, jsonPathFilter, + request.getTopK()); + + logger.debug("SQL query: " + sql); + + return this.jdbcTemplate.query(sql, new DocumentRowMapper(), embeddingVector); + } + else { + if (!this.forcedNormalization + || (this.distanceType != OracleVectorStoreDistanceType.COSINE && this.distanceType != DOT)) { + throw new RuntimeException( + "Similarity threshold filtering requires all vectors to be normalized, see the forcedNormalization parameter for this Vector store. Also only COSINE and DOT distance types are supported."); + } + + final double distance = this.distanceType == DOT ? (1d - request.getSimilarityThreshold()) * 2d - 1d + : 1d - request.getSimilarityThreshold(); + + if (StringUtils.hasText(nativeFilterExpression)) { + jsonPathFilter = String.format(" and JSON_EXISTS( metadata, '%s' )", nativeFilterExpression); + } + + final String sql = this.distanceType == DOT ? (this.searchAccuracy == DEFAULT_SEARCH_ACCURACY + ? String.format( + """ + select id, content, metadata, embedding, (1+VECTOR_DISTANCE(embedding, ?, DOT))/2 as distance + from %s + where VECTOR_DISTANCE(embedding, ?, DOT) <= ?%s + order by distance + fetch first %d rows only""", + this.tableName, jsonPathFilter, request.getTopK()) + : String.format( + """ + select id, content, metadata, embedding, (1+VECTOR_DISTANCE(embedding, ?, DOT))/2 as distance + from %s + where VECTOR_DISTANCE(embedding, ?, DOT) <= ?%s + order by distance + fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", + this.tableName, jsonPathFilter, request.getTopK(), this.searchAccuracy) + + ) : (this.searchAccuracy == DEFAULT_SEARCH_ACCURACY ? String.format(""" + select id, content, metadata, embedding, VECTOR_DISTANCE(embedding, ?, COSINE) as distance + from %s + where VECTOR_DISTANCE(embedding, ?, COSINE) <= ?%s + order by distance + fetch first %d rows only""", this.tableName, jsonPathFilter, request.getTopK()) + : String.format( + """ + select id, content, metadata, embedding, VECTOR_DISTANCE(embedding, ?, COSINE) as distance + from %s + where VECTOR_DISTANCE(embedding, ?, COSINE) <= ?%s + order by distance + fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", + this.tableName, jsonPathFilter, request.getTopK(), this.searchAccuracy)); + + logger.debug("SQL query: " + sql); + + return this.jdbcTemplate.query(sql, new DocumentRowMapper(), embeddingVector, embeddingVector, + distance); + } + } + catch (SQLException sqle) { + throw new RuntimeException(sqle); + } + } + + @Override + public void afterPropertiesSet() throws Exception { + if (this.initializeSchema) { + // Remove existing VectorStoreTable + if (this.removeExistingVectorStoreTable) { + this.jdbcTemplate.execute(String.format("drop table if exists %s purge", this.tableName)); + } + + this.jdbcTemplate.execute(String.format(""" + create table if not exists %s ( + id varchar2(36) default sys_guid() primary key, + content clob not null, + metadata json not null, + embedding vector(%s,FLOAT64) annotations(Distance '%s', IndexType '%s') + )""", this.tableName, this.dimensions == DEFAULT_DIMENSIONS ? "*" : String.valueOf(this.dimensions), + this.distanceType.name(), this.indexType.name())); + + if (logger.isDebugEnabled()) { + this.jdbcTemplate.execute(String.format(""" + create table if not exists debug ( + id varchar2(36) default sys_guid() primary key, + embedding vector(%s,FLOAT64) annotations(Distance '%s') + )""", this.dimensions == DEFAULT_DIMENSIONS ? "*" : String.valueOf(this.dimensions), + this.distanceType.name())); + } + + switch (this.indexType) { + case IVF: + this.jdbcTemplate.execute(String.format(""" + create vector index if not exists vector_index_%s on %s (embedding) + organization neighbor partitions + distance %s + with target accuracy %d + parameters (type IVF, neighbor partitions 10)""", this.tableName, + this.tableName, this.distanceType.name(), + this.searchAccuracy == DEFAULT_SEARCH_ACCURACY ? 95 : this.searchAccuracy)); + break; + + /* + * TODO: Enable for 23.5 case HNSW: + * this.jdbcTemplate.execute(String.format(""" create vector index if not + * exists vector_index_%s on %s (embedding) organization inmemory neighbor + * graph distance %s with target accuracy %d parameters (type HNSW, + * neighbors 40, efconstruction 500)""", tableName, tableName, + * distanceType.name(), searchAccuracy == DEFAULT_SEARCH_ACCURACY ? 95 : + * searchAccuracy)); break; + */ + } + } + } + + public String getTableName() { + return this.tableName; + } + + @Override + public Builder createObservationContextBuilder(String operationName) { + return VectorStoreObservationContext.builder(VectorStoreProvider.ORACLE.value(), operationName) + .withDimensions(this.embeddingModel.dimensions()) + .withCollectionName(this.getTableName()) + .withSimilarityMetric(getSimilarityMetric()); + } + + private String getSimilarityMetric() { + if (!SIMILARITY_TYPE_MAPPING.containsKey(this.distanceType)) { + return this.distanceType.name(); + } + return SIMILARITY_TYPE_MAPPING.get(this.distanceType).value(); + } public enum OracleVectorStoreIndexType { @@ -174,255 +626,6 @@ public class OracleVectorStore extends AbstractObservationVectorStore implements } - public static final String DEFAULT_TABLE_NAME = "SPRING_AI_VECTORS"; - - public static final OracleVectorStoreIndexType DEFAULT_INDEX_TYPE = OracleVectorStoreIndexType.IVF; - - public static final OracleVectorStoreDistanceType DEFAULT_DISTANCE_TYPE = OracleVectorStoreDistanceType.COSINE; - - public static final int DEFAULT_DIMENSIONS = -1; - - public static final int DEFAULT_SEARCH_ACCURACY = -1; - - private final JdbcTemplate jdbcTemplate; - - private final EmbeddingModel embeddingModel; - - private final boolean initializeSchema; - - private final boolean removeExistingVectorStoreTable; - - public final FilterExpressionConverter filterExpressionConverter = new SqlJsonPathFilterExpressionConverter(); - - /** - * Table name where vectors will be stored. - */ - private final String tableName; - - /** - * Index type used to index the vectors. It can impact performance and database memory - * consumption. - */ - private final OracleVectorStoreIndexType indexType; - - /** - * Distance type to use for computing vector distances. - */ - private final OracleVectorStoreDistanceType distanceType; - - /** - * Expected number of dimensions for vectors. Enforcing vector dimensions is very - * useful to ensure future vector distance computations will be relevant. - */ - private final int dimensions; - - private final boolean forcedNormalization; - - private final int searchAccuracy; - - private final BatchingStrategy batchingStrategy; - - public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { - this(jdbcTemplate, embeddingModel, DEFAULT_TABLE_NAME, DEFAULT_INDEX_TYPE, DEFAULT_DISTANCE_TYPE, - DEFAULT_DIMENSIONS, DEFAULT_SEARCH_ACCURACY, false, false, false); - } - - public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, boolean initializeSchema) { - this(jdbcTemplate, embeddingModel, DEFAULT_TABLE_NAME, DEFAULT_INDEX_TYPE, DEFAULT_DISTANCE_TYPE, - DEFAULT_DIMENSIONS, DEFAULT_SEARCH_ACCURACY, initializeSchema, false, false); - } - - public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, String tableName, - OracleVectorStoreIndexType indexType, OracleVectorStoreDistanceType distanceType, int dimensions, - int searchAccuracy, boolean initializeSchema, boolean removeExistingVectorStoreTable, - boolean forcedNormalization) { - this(jdbcTemplate, embeddingModel, tableName, indexType, distanceType, dimensions, searchAccuracy, - initializeSchema, removeExistingVectorStoreTable, forcedNormalization, ObservationRegistry.NOOP, null, - new TokenCountBatchingStrategy()); - } - - public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, String tableName, - OracleVectorStoreIndexType indexType, OracleVectorStoreDistanceType distanceType, int dimensions, - int searchAccuracy, boolean initializeSchema, boolean removeExistingVectorStoreTable, - boolean forcedNormalization, ObservationRegistry observationRegistry, - VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { - - super(observationRegistry, customObservationConvention); - - if (dimensions != DEFAULT_DIMENSIONS) { - if (dimensions <= 0) { - throw new RuntimeException("Number of dimensions must be strictly positive"); - } - if (dimensions > 65535) { - throw new RuntimeException("Number of dimensions must be at most 65535"); - } - } - - if (searchAccuracy != DEFAULT_SEARCH_ACCURACY) { - if (searchAccuracy < 1) { - throw new RuntimeException("Search accuracy must be greater or equals to 1"); - } - if (searchAccuracy > 100) { - throw new RuntimeException("Search accuracy must be lower or equals to 100"); - } - } - - this.jdbcTemplate = jdbcTemplate; - this.embeddingModel = embeddingModel; - this.tableName = tableName; - this.indexType = indexType; - this.distanceType = distanceType; - this.dimensions = dimensions; - this.searchAccuracy = searchAccuracy; - this.initializeSchema = initializeSchema; - this.removeExistingVectorStoreTable = removeExistingVectorStoreTable; - this.forcedNormalization = forcedNormalization; - this.batchingStrategy = batchingStrategy; - } - - @Override - public void doAdd(final List documents) { - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); - this.jdbcTemplate.batchUpdate(getIngestStatement(), new BatchPreparedStatementSetter() { - @Override - public void setValues(PreparedStatement ps, int i) throws SQLException { - final Document document = documents.get(i); - final String content = document.getContent(); - final byte[] json = toJson(document.getMetadata()); - final VECTOR embeddingVector = toVECTOR(document.getEmbedding()); - - setParameterValue(ps, 1, Types.VARCHAR, document.getId()); - setParameterValue(ps, 2, Types.VARCHAR, content); - setParameterValue(ps, 3, OracleType.JSON.getVendorTypeNumber(), json); - setParameterValue(ps, 4, OracleType.VECTOR.getVendorTypeNumber(), embeddingVector); - } - - @Override - public int getBatchSize() { - return documents.size(); - } - }); - } - - private String getIngestStatement() { - return String - .format(""" - merge into %s target using (values(?, ?, ?, ?)) source (id, content, metadata, embedding) on (target.id = source.id) - when matched then update set target.content = source.content, target.metadata = source.metadata, target.embedding = source.embedding - when not matched then insert (target.id, target.content, target.metadata, target.embedding) values (source.id, source.content, source.metadata, source.embedding)""", - tableName); - } - - private final OracleJsonFactory osonFactory = new OracleJsonFactory(); - - private final ByteArrayOutputStream out = new ByteArrayOutputStream(); - - /** - * Bind binary JSON from the client. - * @param m map of metadata - * @return the binary JSON ready to be inserted - */ - private byte[] toJson(final Map m) { - out.reset(); - try (OracleJsonGenerator gen = osonFactory.createJsonBinaryGenerator(out)) { - gen.writeStartObject(); - for (String key : m.keySet()) { - final Object o = m.get(key); - if (o instanceof String) { - gen.write(key, (String) o); - } - else if (o instanceof Integer) { - gen.write(key, (Integer) o); - } - else if (o instanceof Float) { - gen.write(key, (Float) o); - } - else if (o instanceof Double) { - gen.write(key, (Double) o); - } - else if (o instanceof Boolean) { - gen.write(key, (Boolean) o); - } - } - gen.writeEnd(); - } - - return out.toByteArray(); - } - - /** - * Converts a list of Double values into an Oracle VECTOR object ready to be inserted. - * Optionally normalize the vector beforehand (see forcedNormalization). - * @param floatList - * @return - * @throws SQLException - */ - private VECTOR toVECTOR(final float[] floatList) throws SQLException { - final double[] doubles = new double[floatList.length]; - int i = 0; - for (double d : floatList) { - doubles[i++] = d; - } - - if (forcedNormalization) { - return VECTOR.ofFloat64Values(normalize(doubles)); - } - - return VECTOR.ofFloat64Values(doubles); - } - - /** - * Normalize a vector if requested. - * @param v vector to normalize - * @return the vector normalized - */ - private double[] normalize(final double[] v) { - double squaredSum = 0d; - - for (double e : v) { - squaredSum += e * e; - } - - final double magnitude = Math.sqrt(squaredSum); - - if (magnitude > 0) { - final double multiplier = 1d / magnitude; - final int length = v.length; - for (int i = 0; i < length; i++) { - v[i] *= multiplier; - } - } - - return v; - } - - @Override - public Optional doDelete(final List idList) { - final String sql = String.format("delete from %s where id=?", tableName); - final int[] argTypes = { Types.VARCHAR }; - - final List batchArgs = new ArrayList<>(); - for (String id : idList) { - batchArgs.add(new Object[] { id }); - } - - final int[] deleteCounts = jdbcTemplate.batchUpdate(sql, batchArgs, argTypes); - - int deleteCount = 0; - for (int detailedResult : deleteCounts) { - switch (detailedResult) { - case Statement.EXECUTE_FAILED: - break; - case 1: - case Statement.SUCCESS_NO_INFO: - deleteCount++; - break; - } - } - - return Optional.of(deleteCount == idList.size()); - } - private static class DocumentRowMapper implements RowMapper { @Override @@ -459,195 +662,4 @@ public class OracleVectorStore extends AbstractObservationVectorStore implements } - @Override - public List doSimilaritySearch(SearchRequest request) { - try { - // From the provided query, generate a vector using the embedding model - final VECTOR embeddingVector = toVECTOR(embeddingModel.embed(request.getQuery())); - - if (logger.isDebugEnabled()) { - this.jdbcTemplate.batchUpdate("insert into debug(embedding) values(?)", - new BatchPreparedStatementSetter() { - @Override - public void setValues(PreparedStatement ps, int i) throws SQLException { - setParameterValue(ps, 1, OracleType.VECTOR.getVendorTypeNumber(), embeddingVector); - } - - @Override - public int getBatchSize() { - return 1; - } - }); - } - - final String nativeFilterExpression = (request.getFilterExpression() != null) - ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : ""; - - String jsonPathFilter = ""; - - if (request.getSimilarityThreshold() == SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) { - if (StringUtils.hasText(nativeFilterExpression)) { - jsonPathFilter = String.format("where JSON_EXISTS( metadata, '%s' )\n", nativeFilterExpression); - } - - final String sql = searchAccuracy == DEFAULT_SEARCH_ACCURACY ? String.format(""" - select id, content, metadata, embedding, %sVECTOR_DISTANCE(embedding, ?, %s)%s as distance - from %s - %sorder by distance - fetch first %d rows only""", distanceType == DOT ? "(1+" : "", distanceType.name(), - distanceType == DOT ? ")/2" : "", tableName, jsonPathFilter, request.getTopK()) - : String.format( - """ - select id, content, metadata, embedding, %sVECTOR_DISTANCE(embedding, ?, %s)%s as distance - from %s - %sorder by distance - fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", - distanceType == DOT ? "(1+" : "", distanceType.name(), distanceType == DOT ? ")/2" : "", - tableName, jsonPathFilter, request.getTopK(), searchAccuracy); - - logger.debug("SQL query: " + sql); - - return this.jdbcTemplate.query(sql, new DocumentRowMapper(), embeddingVector); - } - else if (request.getSimilarityThreshold() == SIMILARITY_THRESHOLD_EXACT_MATCH) { - if (StringUtils.hasText(nativeFilterExpression)) { - jsonPathFilter = String.format("where JSON_EXISTS( metadata, '%s' )\n", nativeFilterExpression); - } - - final String sql = String.format(""" - select id, content, metadata, embedding, %sVECTOR_DISTANCE(embedding, ?, %s)%s as distance - from %s - %sorder by distance - fetch EXACT first %d rows only""", distanceType == DOT ? "(1+" : "", distanceType.name(), - distanceType == DOT ? ")/2" : "", tableName, jsonPathFilter, request.getTopK()); - - logger.debug("SQL query: " + sql); - - return this.jdbcTemplate.query(sql, new DocumentRowMapper(), embeddingVector); - } - else { - if (!forcedNormalization - || (distanceType != OracleVectorStoreDistanceType.COSINE && distanceType != DOT)) { - throw new RuntimeException( - "Similarity threshold filtering requires all vectors to be normalized, see the forcedNormalization parameter for this Vector store. Also only COSINE and DOT distance types are supported."); - } - - final double distance = distanceType == DOT ? (1d - request.getSimilarityThreshold()) * 2d - 1d - : 1d - request.getSimilarityThreshold(); - - if (StringUtils.hasText(nativeFilterExpression)) { - jsonPathFilter = String.format(" and JSON_EXISTS( metadata, '%s' )", nativeFilterExpression); - } - - final String sql = distanceType == DOT ? (searchAccuracy == DEFAULT_SEARCH_ACCURACY ? String.format(""" - select id, content, metadata, embedding, (1+VECTOR_DISTANCE(embedding, ?, DOT))/2 as distance - from %s - where VECTOR_DISTANCE(embedding, ?, DOT) <= ?%s - order by distance - fetch first %d rows only""", tableName, jsonPathFilter, request.getTopK()) : String.format(""" - select id, content, metadata, embedding, (1+VECTOR_DISTANCE(embedding, ?, DOT))/2 as distance - from %s - where VECTOR_DISTANCE(embedding, ?, DOT) <= ?%s - order by distance - fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", tableName, jsonPathFilter, - request.getTopK(), searchAccuracy) - - ) : (searchAccuracy == DEFAULT_SEARCH_ACCURACY ? String.format(""" - select id, content, metadata, embedding, VECTOR_DISTANCE(embedding, ?, COSINE) as distance - from %s - where VECTOR_DISTANCE(embedding, ?, COSINE) <= ?%s - order by distance - fetch first %d rows only""", tableName, jsonPathFilter, request.getTopK()) : String.format(""" - select id, content, metadata, embedding, VECTOR_DISTANCE(embedding, ?, COSINE) as distance - from %s - where VECTOR_DISTANCE(embedding, ?, COSINE) <= ?%s - order by distance - fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", tableName, jsonPathFilter, - request.getTopK(), searchAccuracy)); - - logger.debug("SQL query: " + sql); - - return this.jdbcTemplate.query(sql, new DocumentRowMapper(), embeddingVector, embeddingVector, - distance); - } - } - catch (SQLException sqle) { - throw new RuntimeException(sqle); - } - } - - @Override - public void afterPropertiesSet() throws Exception { - if (this.initializeSchema) { - // Remove existing VectorStoreTable - if (this.removeExistingVectorStoreTable) { - this.jdbcTemplate.execute(String.format("drop table if exists %s purge", tableName)); - } - - this.jdbcTemplate.execute(String.format(""" - create table if not exists %s ( - id varchar2(36) default sys_guid() primary key, - content clob not null, - metadata json not null, - embedding vector(%s,FLOAT64) annotations(Distance '%s', IndexType '%s') - )""", tableName, dimensions == DEFAULT_DIMENSIONS ? "*" : String.valueOf(dimensions), - distanceType.name(), indexType.name())); - - if (logger.isDebugEnabled()) { - this.jdbcTemplate.execute(String.format(""" - create table if not exists debug ( - id varchar2(36) default sys_guid() primary key, - embedding vector(%s,FLOAT64) annotations(Distance '%s') - )""", dimensions == DEFAULT_DIMENSIONS ? "*" : String.valueOf(dimensions), - distanceType.name())); - } - - switch (indexType) { - case IVF: - this.jdbcTemplate.execute(String.format(""" - create vector index if not exists vector_index_%s on %s (embedding) - organization neighbor partitions - distance %s - with target accuracy %d - parameters (type IVF, neighbor partitions 10)""", tableName, tableName, - distanceType.name(), searchAccuracy == DEFAULT_SEARCH_ACCURACY ? 95 : searchAccuracy)); - break; - - /* - * TODO: Enable for 23.5 case HNSW: - * this.jdbcTemplate.execute(String.format(""" create vector index if not - * exists vector_index_%s on %s (embedding) organization inmemory neighbor - * graph distance %s with target accuracy %d parameters (type HNSW, - * neighbors 40, efconstruction 500)""", tableName, tableName, - * distanceType.name(), searchAccuracy == DEFAULT_SEARCH_ACCURACY ? 95 : - * searchAccuracy)); break; - */ - } - } - } - - public String getTableName() { - return tableName; - } - - @Override - public Builder createObservationContextBuilder(String operationName) { - return VectorStoreObservationContext.builder(VectorStoreProvider.ORACLE.value(), operationName) - .withDimensions(this.embeddingModel.dimensions()) - .withCollectionName(this.getTableName()) - .withSimilarityMetric(getSimilarityMetric()); - } - - private static Map SIMILARITY_TYPE_MAPPING = Map.of( - OracleVectorStoreDistanceType.COSINE, VectorStoreSimilarityMetric.COSINE, - OracleVectorStoreDistanceType.EUCLIDEAN, VectorStoreSimilarityMetric.EUCLIDEAN, - OracleVectorStoreDistanceType.DOT, VectorStoreSimilarityMetric.DOT); - - private String getSimilarityMetric() { - if (!SIMILARITY_TYPE_MAPPING.containsKey(this.distanceType)) { - return this.distanceType.name(); - } - return SIMILARITY_TYPE_MAPPING.get(this.distanceType).value(); - } - } diff --git a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverter.java b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverter.java index ad3432b7c..0fd446257 100644 --- a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverter.java +++ b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverter.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore; import org.springframework.ai.vectorstore.filter.Filter; diff --git a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleImage.java b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleImage.java index 3f2250a3c..6955a6606 100644 --- a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleImage.java +++ b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java index b78790ecd..638db84b5 100644 --- a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java +++ b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java @@ -1,10 +1,41 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import javax.sql.DataSource; + import oracle.jdbc.pool.OracleDataSource; import org.junit.Assert; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.ValueSource; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.oracle.OracleContainer; +import org.testcontainers.utility.MountableFile; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -22,19 +53,6 @@ import org.springframework.context.annotation.Primary; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.util.CollectionUtils; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.oracle.OracleContainer; -import org.testcontainers.utility.MountableFile; - -import javax.sql.DataSource; -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.UUID; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.vectorstore.OracleVectorStore.DEFAULT_SEARCH_ACCURACY; @@ -51,15 +69,6 @@ public class OracleVectorStoreIT { new Document(getText("classpath:/test/data/time.shelter.txt")), new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); - public static String getText(final String uri) { - try { - return new DefaultResourceLoader().getResource(uri).getContentAsString(StandardCharsets.UTF_8); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withUserConfiguration(TestClient.class) .withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=COSINE", @@ -70,52 +79,13 @@ public class OracleVectorStoreIT { String.format("app.datasource.password=%s", oracle23aiContainer.getPassword()), "app.datasource.type=oracle.jdbc.pool.OracleDataSource"); - @SpringBootConfiguration - @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) - public static class TestClient { - - @Value("${test.spring.ai.vectorstore.oracle.distanceType}") - OracleVectorStore.OracleVectorStoreDistanceType distanceType; - - @Value("${test.spring.ai.vectorstore.oracle.searchAccuracy}") - int searchAccuracy; - - @Bean - public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { - return new OracleVectorStore(jdbcTemplate, embeddingModel, OracleVectorStore.DEFAULT_TABLE_NAME, - OracleVectorStore.OracleVectorStoreIndexType.IVF, distanceType, 384, searchAccuracy, true, true, - true); + public static String getText(final String uri) { + try { + return new DefaultResourceLoader().getResource(uri).getContentAsString(StandardCharsets.UTF_8); } - - @Bean - public JdbcTemplate myJdbcTemplate(DataSource dataSource) { - return new JdbcTemplate(dataSource); + catch (IOException e) { + throw new RuntimeException(e); } - - @Bean - @Primary - @ConfigurationProperties("app.datasource") - public DataSourceProperties dataSourceProperties() { - return new DataSourceProperties(); - } - - @Bean - public OracleDataSource dataSource(DataSourceProperties dataSourceProperties) { - return dataSourceProperties.initializeDataSourceBuilder().type(OracleDataSource.class).build(); - } - - @Bean - public EmbeddingModel embeddingModel() { - try { - TransformersEmbeddingModel tem = new TransformersEmbeddingModel(); - tem.afterPropertiesSet(); - return tem; - } - catch (Exception e) { - throw new RuntimeException("Failed initializing embedding model", e); - } - } - } private static void dropTable(ApplicationContext context, String tableName) { @@ -123,27 +93,49 @@ public class OracleVectorStoreIT { jdbcTemplate.execute("DROP TABLE IF EXISTS " + tableName + " PURGE"); } + private static boolean isSortedByDistance(final List documents) { + final List distances = documents.stream() + .map(doc -> (Double) doc.getMetadata().get("distance")) + .toList(); + + if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { + return true; + } + + Iterator iter = distances.iterator(); + Double current; + Double previous = iter.next(); + while (iter.hasNext()) { + current = iter.next(); + if (previous > current) { + return false; + } + previous = current; + } + return true; + } + @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "COSINE", "DOT", "EUCLIDEAN", "EUCLIDEAN_SQUARED", "MANHATTAN" }) public void addAndSearch(String distanceType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) .withPropertyValues("test.spring.ai.vectorstore.oracle.searchAccuracy=" + DEFAULT_SEARCH_ACCURACY) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List results = vectorStore .similaritySearch(SearchRequest.query("What is Great Depression").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); List results2 = vectorStore .similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); @@ -157,7 +149,7 @@ public class OracleVectorStoreIT { @CsvSource({ "COSINE,-1", "DOT,-1", "EUCLIDEAN,-1", "EUCLIDEAN_SQUARED,-1", "MANHATTAN,-1", "COSINE,75", "DOT,80", "EUCLIDEAN,60", "EUCLIDEAN_SQUARED,30", "MANHATTAN,42" }) public void searchWithFilters(String distanceType, int searchAccuracy) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) .withPropertyValues("test.spring.ai.vectorstore.oracle.searchAccuracy=" + searchAccuracy) .run(context -> { @@ -231,7 +223,7 @@ public class OracleVectorStoreIT { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "COSINE", "DOT", "EUCLIDEAN", "EUCLIDEAN_SQUARED", "MANHATTAN" }) public void documentUpdate(String distanceType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) .withPropertyValues("test.spring.ai.vectorstore.oracle.searchAccuracy=" + DEFAULT_SEARCH_ACCURACY) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -270,13 +262,13 @@ public class OracleVectorStoreIT { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "COSINE", "DOT" }) public void searchWithThreshold(String distanceType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) .withPropertyValues("test.spring.ai.vectorstore.oracle.searchAccuracy=" + DEFAULT_SEARCH_ACCURACY) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List fullResult = vectorStore .similaritySearch(SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThresholdAll()); @@ -296,32 +288,58 @@ public class OracleVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(1).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId()); dropTable(context, ((OracleVectorStore) vectorStore).getTableName()); }); } - private static boolean isSortedByDistance(final List documents) { - final List distances = documents.stream() - .map(doc -> (Double) doc.getMetadata().get("distance")) - .toList(); + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + public static class TestClient { - if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { - return true; + @Value("${test.spring.ai.vectorstore.oracle.distanceType}") + OracleVectorStore.OracleVectorStoreDistanceType distanceType; + + @Value("${test.spring.ai.vectorstore.oracle.searchAccuracy}") + int searchAccuracy; + + @Bean + public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { + return new OracleVectorStore(jdbcTemplate, embeddingModel, OracleVectorStore.DEFAULT_TABLE_NAME, + OracleVectorStore.OracleVectorStoreIndexType.IVF, this.distanceType, 384, this.searchAccuracy, true, + true, true); } - Iterator iter = distances.iterator(); - Double current; - Double previous = iter.next(); - while (iter.hasNext()) { - current = iter.next(); - if (previous > current) { - return false; + @Bean + public JdbcTemplate myJdbcTemplate(DataSource dataSource) { + return new JdbcTemplate(dataSource); + } + + @Bean + @Primary + @ConfigurationProperties("app.datasource") + public DataSourceProperties dataSourceProperties() { + return new DataSourceProperties(); + } + + @Bean + public OracleDataSource dataSource(DataSourceProperties dataSourceProperties) { + return dataSourceProperties.initializeDataSourceBuilder().type(OracleDataSource.class).build(); + } + + @Bean + public EmbeddingModel embeddingModel() { + try { + TransformersEmbeddingModel tem = new TransformersEmbeddingModel(); + tem.afterPropertiesSet(); + return tem; + } + catch (Exception e) { + throw new RuntimeException("Failed initializing embedding model", e); } - previous = current; } - return true; + } } diff --git a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreObservationIT.java b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreObservationIT.java index 8d2636fbf..50a496e44 100644 --- a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,7 +23,16 @@ import java.util.Map; import javax.sql.DataSource; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; +import oracle.jdbc.pool.OracleDataSource; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.oracle.OracleContainer; +import org.testcontainers.utility.MountableFile; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -47,15 +55,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Primary; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.jdbc.core.JdbcTemplate; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.oracle.OracleContainer; -import org.testcontainers.utility.MountableFile; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import oracle.jdbc.pool.OracleDataSource; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -100,13 +101,13 @@ public class OracleVectorStoreObservationIT { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverterTests.java b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverterTests.java index 23f35e2c9..165e0a853 100644 --- a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/SqlJsonPathFilterExpressionConverterTests.java @@ -1,11 +1,28 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.jupiter.api.Test; + import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser; +import static org.assertj.core.api.Assertions.assertThat; + public class SqlJsonPathFilterExpressionConverterTests { @Test diff --git a/vector-stores/spring-ai-oracle-store/src/test/resources/initialize.sql b/vector-stores/spring-ai-oracle-store/src/test/resources/initialize.sql index ac38a1965..0b42b6ff7 100644 --- a/vector-stores/spring-ai-oracle-store/src/test/resources/initialize.sql +++ b/vector-stores/spring-ai-oracle-store/src/test/resources/initialize.sql @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + -- Exit on any errors WHENEVER SQLERROR EXIT SQL.SQLCODE diff --git a/vector-stores/spring-ai-pgvector-store/pom.xml b/vector-stores/spring-ai-pgvector-store/pom.xml index 2cd142626..7f3f53416 100644 --- a/vector-stores/spring-ai-pgvector-store/pom.xml +++ b/vector-stores/spring-ai-pgvector-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java index 18d8e23bc..06db63670 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.List; + import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.ai.vectorstore.filter.Filter.Group; import org.springframework.ai.vectorstore.filter.Filter.Key; import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; -import java.util.List; /** * Converts {@link Expression} into PgVector metadata filter expression format. diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorSchemaValidator.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorSchemaValidator.java index f8017d97b..5e7c30d39 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorSchemaValidator.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorSchemaValidator.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.ArrayList; @@ -21,6 +22,7 @@ import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.dao.DataAccessException; import org.springframework.jdbc.core.JdbcTemplate; @@ -64,7 +66,7 @@ public class PgVectorSchemaValidator { String sql = "SELECT 1 FROM information_schema.tables WHERE table_schema = ? AND table_name = ?"; try { // Query for a single integer value, if it exists, table exists - jdbcTemplate.queryForObject(sql, Integer.class, schemaName, tableName); + this.jdbcTemplate.queryForObject(sql, Integer.class, schemaName, tableName); return true; } catch (DataAccessException e) { @@ -100,7 +102,7 @@ public class PgVectorSchemaValidator { // Include the schema name in the query to target the correct table String query = "SELECT column_name, data_type FROM information_schema.columns " + "WHERE table_schema = ? AND table_name = ?"; - List> columns = jdbcTemplate.queryForList(query, + List> columns = this.jdbcTemplate.queryForList(query, new Object[] { schemaName, tableName }); if (columns.isEmpty()) { diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java index 486902fb8..56bb866e6 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.json.JsonMapper; @@ -23,6 +33,7 @@ import io.micrometer.observation.ObservationRegistry; import org.postgresql.util.PGobject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -44,15 +55,6 @@ import org.springframework.jdbc.core.StatementCreatorUtils; import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; - /** * Uses the "vector_store" table to store the Spring AI vector data. The table and the * vector index will be auto-created if not available. @@ -67,8 +69,6 @@ import java.util.UUID; */ public class PgVectorStore extends AbstractObservationVectorStore implements InitializingBean { - private static final Logger logger = LoggerFactory.getLogger(PgVectorStore.class); - public static final int OPENAI_EMBEDDING_DIMENSION_SIZE = 1536; public static final int INVALID_EMBEDDING_DIMENSION = -1; @@ -81,10 +81,17 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini public static final boolean DEFAULT_SCHEMA_VALIDATION = false; - public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter(); - public static final int MAX_DOCUMENT_BATCH_SIZE = 10_000; + private static final Logger logger = LoggerFactory.getLogger(PgVectorStore.class); + + private static Map SIMILARITY_TYPE_MAPPING = Map.of( + PgDistanceType.COSINE_DISTANCE, VectorStoreSimilarityMetric.COSINE, PgDistanceType.EUCLIDEAN_DISTANCE, + VectorStoreSimilarityMetric.EUCLIDEAN, PgDistanceType.NEGATIVE_INNER_PRODUCT, + VectorStoreSimilarityMetric.DOT); + + public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter(); + private final String vectorTableName; private final String vectorIndexName; @@ -183,7 +190,7 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini } public PgDistanceType getDistanceType() { - return distanceType; + return this.distanceType; } @Override @@ -208,6 +215,7 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? "; this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() { + @Override public void setValues(PreparedStatement ps, int i) throws SQLException { @@ -247,7 +255,7 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini public Optional doDelete(List idList) { int updateCount = 0; for (String id : idList) { - int count = jdbcTemplate.update("DELETE FROM " + getFullyQualifiedTableName() + " WHERE id = ?", + int count = this.jdbcTemplate.update("DELETE FROM " + getFullyQualifiedTableName() + " WHERE id = ?", UUID.fromString(id)); updateCount = updateCount + count; } @@ -281,6 +289,7 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini return this.jdbcTemplate.query( "SELECT embedding " + this.comparisonOperator() + " ? AS distance FROM " + getFullyQualifiedTableName(), new RowMapper() { + @Override @Nullable public Double mapRow(ResultSet rs, int rowNum) throws SQLException { @@ -383,6 +392,23 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini return OPENAI_EMBEDDING_DIMENSION_SIZE; } + @Override + public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { + + return VectorStoreObservationContext.builder(VectorStoreProvider.PG_VECTOR.value(), operationName) + .withCollectionName(this.vectorTableName) + .withDimensions(this.embeddingDimensions()) + .withNamespace(this.schemaName) + .withSimilarityMetric(getSimilarityMetric()); + } + + private String getSimilarityMetric() { + if (!SIMILARITY_TYPE_MAPPING.containsKey(this.getDistanceType())) { + return this.getDistanceType().name(); + } + return SIMILARITY_TYPE_MAPPING.get(this.distanceType).value(); + } + /** * By default, pgvector performs exact nearest neighbor search, which provides perfect * recall. You can add an index to use approximate nearest neighbor search, which @@ -492,7 +518,7 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini String source = pgObject.getValue(); try { - return (Map) objectMapper.readValue(source, Map.class); + return (Map) this.objectMapper.readValue(source, Map.class); } catch (JsonProcessingException e) { throw new RuntimeException(e); @@ -611,26 +637,4 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini } - @Override - public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { - - return VectorStoreObservationContext.builder(VectorStoreProvider.PG_VECTOR.value(), operationName) - .withCollectionName(this.vectorTableName) - .withDimensions(this.embeddingDimensions()) - .withNamespace(this.schemaName) - .withSimilarityMetric(getSimilarityMetric()); - } - - private static Map SIMILARITY_TYPE_MAPPING = Map.of( - PgDistanceType.COSINE_DISTANCE, VectorStoreSimilarityMetric.COSINE, PgDistanceType.EUCLIDEAN_DISTANCE, - VectorStoreSimilarityMetric.EUCLIDEAN, PgDistanceType.NEGATIVE_INNER_PRODUCT, - VectorStoreSimilarityMetric.DOT); - - private String getSimilarityMetric() { - if (!SIMILARITY_TYPE_MAPPING.containsKey(this.getDistanceType())) { - return this.getDistanceType().name(); - } - return SIMILARITY_TYPE_MAPPING.get(this.distanceType).value(); - } - -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorEmbeddingDimensionsTests.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorEmbeddingDimensionsTests.java index 4fa2c56a5..efef6d913 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorEmbeddingDimensionsTests.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorEmbeddingDimensionsTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.junit.jupiter.api.Test; @@ -46,32 +47,32 @@ public class PgVectorEmbeddingDimensionsTests { final int explicitDimensions = 696; - var dim = new PgVectorStore(jdbcTemplate, embeddingModel, explicitDimensions).embeddingDimensions(); + var dim = new PgVectorStore(this.jdbcTemplate, this.embeddingModel, explicitDimensions).embeddingDimensions(); assertThat(dim).isEqualTo(explicitDimensions); - verify(embeddingModel, never()).dimensions(); + verify(this.embeddingModel, never()).dimensions(); } @Test public void embeddingModelDimensions() { - when(embeddingModel.dimensions()).thenReturn(969); + when(this.embeddingModel.dimensions()).thenReturn(969); - var dim = new PgVectorStore(jdbcTemplate, embeddingModel).embeddingDimensions(); + var dim = new PgVectorStore(this.jdbcTemplate, this.embeddingModel).embeddingDimensions(); assertThat(dim).isEqualTo(969); - verify(embeddingModel, only()).dimensions(); + verify(this.embeddingModel, only()).dimensions(); } @Test public void fallBackToDefaultDimensions() { - when(embeddingModel.dimensions()).thenThrow(new RuntimeException()); + when(this.embeddingModel.dimensions()).thenThrow(new RuntimeException()); - var dim = new PgVectorStore(jdbcTemplate, embeddingModel).embeddingDimensions(); + var dim = new PgVectorStore(this.jdbcTemplate, this.embeddingModel).embeddingDimensions(); assertThat(dim).isEqualTo(PgVectorStore.OPENAI_EMBEDDING_DIMENSION_SIZE); - verify(embeddingModel, only()).dimensions(); + verify(this.embeddingModel, only()).dimensions(); } } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverterTests.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverterTests.java index ca662ab8e..b2ef770cc 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; -import static org.assertj.core.api.Assertions.assertThat; import org.junit.jupiter.api.Test; + import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.Group; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.Filter.Value; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; + +import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AND; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.GTE; @@ -28,10 +35,6 @@ import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.LT import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NE; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; -import org.springframework.ai.vectorstore.filter.Filter.Group; -import org.springframework.ai.vectorstore.filter.Filter.Key; -import org.springframework.ai.vectorstore.filter.Filter.Value; -import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; /** * @author Muthukumaran Navaneethakrishnan @@ -44,14 +47,14 @@ public class PgVectorFilterExpressionConverterTests { @Test public void testEQ() { // country == "BG" - String vectorExpr = converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); + String vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("$.country == \"BG\""); } @Test public void tesEqAndGte() { // genre == "drama" AND year >= 2020 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); assertThat(vectorExpr).isEqualTo("$.genre == \"drama\" && $.year >= 2020"); @@ -60,7 +63,7 @@ public class PgVectorFilterExpressionConverterTests { @Test public void tesIn() { // genre in ["comedy", "documentary", "drama"] - String vectorExpr = converter.convertExpression( + String vectorExpr = this.converter.convertExpression( new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr) .isEqualTo("($.genre == \"comedy\" || $.genre == \"documentary\" || $.genre == \"drama\")"); @@ -69,7 +72,7 @@ public class PgVectorFilterExpressionConverterTests { @Test public void testNe() { // year >= 2020 OR country == "BG" AND city != "Sofia" - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), new Expression(NE, new Key("city"), new Value("Sofia"))))); @@ -79,7 +82,7 @@ public class PgVectorFilterExpressionConverterTests { @Test public void testGroup() { // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(EQ, new Key("country"), new Value("BG")))), new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv"))))); @@ -90,7 +93,7 @@ public class PgVectorFilterExpressionConverterTests { @Test public void tesBoolean() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - String vectorExpr = converter.convertExpression(new Expression(AND, + String vectorExpr = this.converter.convertExpression(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), new Expression(GTE, new Key("year"), new Value(2020))), new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); @@ -102,7 +105,7 @@ public class PgVectorFilterExpressionConverterTests { @Test public void testDecimal() { // temperature >= -15.6 && temperature <= +20.13 - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), new Expression(LTE, new Key("temperature"), new Value(20.13)))); @@ -111,7 +114,7 @@ public class PgVectorFilterExpressionConverterTests { @Test public void testComplexIdentifiers() { - String vectorExpr = converter + String vectorExpr = this.converter .convertExpression(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); assertThat(vectorExpr).isEqualTo("$.\"country 1 2 3\" == \"BG\""); } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorImage.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorImage.java index 5e4204cdb..0df031e63 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorImage.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreCustomNamesIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreCustomNamesIT.java index 5d34e98c6..626085dba 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreCustomNamesIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreCustomNamesIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.Random; + +import javax.sql.DataSource; + import com.zaxxer.hikari.HikariDataSource; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.openai.OpenAiEmbeddingModel; import org.springframework.ai.openai.api.OpenAiApi; @@ -32,12 +41,6 @@ import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Primary; import org.springframework.jdbc.core.JdbcTemplate; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import javax.sql.DataSource; -import java.util.Random; import static org.assertj.core.api.Assertions.assertThat; @@ -92,19 +95,20 @@ public class PgVectorStoreCustomNamesIT { @Test public void shouldCreateDefaultTableAndIndexIfNotPresentInConfig() { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.schemaValidation=false").run(context -> { - assertThat(context).hasNotFailed(); - assertThat(isTableExists(context, "vector_store")).isTrue(); - assertThat(isSchemaExists(context, "public")).isTrue(); - dropTableByName(context, "vector_store"); + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.schemaValidation=false") + .run(context -> { + assertThat(context).hasNotFailed(); + assertThat(isTableExists(context, "vector_store")).isTrue(); + assertThat(isSchemaExists(context, "public")).isTrue(); + dropTableByName(context, "vector_store"); - }); + }); } @Test public void shouldCreateTableAndIndexIfNotPresentInDatabase() { String tableName = "new_vector_table"; - contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.vectorTableName=" + tableName) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.vectorTableName=" + tableName) .run(context -> { assertThat(isTableExists(context, tableName)).isTrue(); assertThat(isIndexExists(context, "public", tableName, tableName + "_index")).isTrue(); @@ -118,7 +122,7 @@ public class PgVectorStoreCustomNamesIT { String tableName = "customvectortable"; - contextRunner + this.contextRunner .withPropertyValues("test.spring.ai.vectorstore.pgvector.vectorTableName=" + tableName, "test.spring.ai.vectorstore.pgvector.schemaValidation=true") @@ -136,7 +140,7 @@ public class PgVectorStoreCustomNamesIT { String tableName = "users; DROP TABLE users;"; - contextRunner + this.contextRunner .withPropertyValues("test.spring.ai.vectorstore.pgvector.vectorTableName=" + tableName, "test.spring.ai.vectorstore.pgvector.schemaValidation=true") @@ -156,7 +160,7 @@ public class PgVectorStoreCustomNamesIT { String schemaName = "public; DROP TABLE users;"; String tableName = "customvectortable"; - contextRunner + this.contextRunner .withPropertyValues("test.spring.ai.vectorstore.pgvector.vectorTableName=" + tableName, "test.spring.ai.vectorstore.pgvector.schemaName=" + schemaName, "test.spring.ai.vectorstore.pgvector.schemaValidation=true") @@ -189,10 +193,10 @@ public class PgVectorStoreCustomNamesIT { @Bean public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { - return new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withSchemaName(schemaName) - .withVectorTableName(vectorTableName) - .withVectorTableValidationsEnabled(schemaValidation) - .withDimensions(dimensions) + return new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withSchemaName(this.schemaName) + .withVectorTableName(this.vectorTableName) + .withVectorTableValidationsEnabled(this.schemaValidation) + .withDimensions(this.dimensions) .withDistanceType(PgVectorStore.PgDistanceType.COSINE_DISTANCE) .withRemoveExistingVectorStoreTable(true) .withIndexType(PgIndexType.HNSW) diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java index f18a1669f..8405d3233 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -28,12 +27,17 @@ import java.util.stream.Stream; import javax.sql.DataSource; +import com.zaxxer.hikari.HikariDataSource; import org.junit.Assert; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.openai.OpenAiEmbeddingModel; @@ -53,11 +57,8 @@ import org.springframework.context.annotation.Primary; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.util.CollectionUtils; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import com.zaxxer.hikari.HikariDataSource; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Muthukumaran Navaneethakrishnan @@ -74,6 +75,16 @@ public class PgVectorStoreIT { .withUsername("postgres") .withPassword("postgres"); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class) + .withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE", + + // JdbcTemplate configuration + String.format("app.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), + postgresContainer.getMappedPort(5432), "postgres"), + "app.datasource.username=postgres", "app.datasource.password=postgres", + "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -89,50 +100,11 @@ public class PgVectorStoreIT { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(TestApplication.class) - .withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE", - - // JdbcTemplate configuration - String.format("app.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), - postgresContainer.getMappedPort(5432), "postgres"), - "app.datasource.username=postgres", "app.datasource.password=postgres", - "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); - private static void dropTable(ApplicationContext context) { JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store"); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "COSINE_DISTANCE", "EUCLIDEAN_DISTANCE", "NEGATIVE_INNER_PRODUCT" }) - public void addAndSearch(String distanceType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) - .run(context -> { - - VectorStore vectorStore = context.getBean(VectorStore.class); - - vectorStore.add(documents); - - List results = vectorStore - .similaritySearch(SearchRequest.query("What is Great Depression").withTopK(1)); - - assertThat(results).hasSize(1); - Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); - - // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); - - List results2 = vectorStore - .similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); - assertThat(results2).hasSize(0); - - dropTable(context); - }); - } - static Stream provideFilters() { return Stream.of(Arguments.of("country in ['BG','NL']", 3), // String Filters In Arguments.of("year in [2020]", 1), // Numeric Filters In @@ -141,11 +113,60 @@ public class PgVectorStoreIT { ); } + private static boolean isSortedByDistance(List docs) { + + List distances = docs.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + + if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { + return true; + } + + Iterator iter = distances.iterator(); + Float current, previous = iter.next(); + while (iter.hasNext()) { + current = iter.next(); + if (previous > current) { + return false; + } + previous = current; + } + return true; + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "COSINE_DISTANCE", "EUCLIDEAN_DISTANCE", "NEGATIVE_INNER_PRODUCT" }) + public void addAndSearch(String distanceType) { + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) + .run(context -> { + + VectorStore vectorStore = context.getBean(VectorStore.class); + + vectorStore.add(this.documents); + + List results = vectorStore + .similaritySearch(SearchRequest.query("What is Great Depression").withTopK(1)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + + // Remove all documents from the store + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); + + List results2 = vectorStore + .similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); + assertThat(results2).hasSize(0); + + dropTable(context); + }); + } + @ParameterizedTest(name = "Filter expression {0} should return {1} records ") @MethodSource("provideFilters") public void searchWithInFilter(String expression, Integer expectedRecords) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE") + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE") .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -177,7 +198,7 @@ public class PgVectorStoreIT { @ValueSource(strings = { "COSINE_DISTANCE", "EUCLIDEAN_DISTANCE", "NEGATIVE_INNER_PRODUCT" }) public void searchWithFilters(String distanceType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -251,7 +272,7 @@ public class PgVectorStoreIT { @ValueSource(strings = { "COSINE_DISTANCE", "EUCLIDEAN_DISTANCE", "NEGATIVE_INNER_PRODUCT" }) public void documentUpdate(String distanceType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -292,12 +313,12 @@ public class PgVectorStoreIT { // @ValueSource(strings = { "COSINE_DISTANCE" }) public void searchWithThreshold(String distanceType) { - contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + distanceType) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List fullResult = vectorStore .similaritySearch(SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThresholdAll()); @@ -317,32 +338,12 @@ public class PgVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(1).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId()); dropTable(context); }); } - private static boolean isSortedByDistance(List docs) { - - List distances = docs.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); - - if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { - return true; - } - - Iterator iter = distances.iterator(); - Float current, previous = iter.next(); - while (iter.hasNext()) { - current = iter.next(); - if (previous > current) { - return false; - } - previous = current; - } - return true; - } - @SpringBootConfiguration @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) public static class TestApplication { @@ -353,7 +354,7 @@ public class PgVectorStoreIT { @Bean public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { return new PgVectorStore(jdbcTemplate, embeddingModel, PgVectorStore.INVALID_EMBEDDING_DIMENSION, - distanceType, true, PgIndexType.HNSW, true); + this.distanceType, true, PgIndexType.HNSW, true); } @Bean diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreObservationIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreObservationIT.java index f60b66b49..963b256e9 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,8 +23,16 @@ import java.util.Map; import javax.sql.DataSource; +import com.zaxxer.hikari.HikariDataSource; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.SpringAiKind; @@ -48,15 +55,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Primary; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.jdbc.core.JdbcTemplate; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import com.zaxxer.hikari.HikariDataSource; - -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * Integration tests for observation instruAbstractObservationVectorStorementation in @@ -75,6 +75,16 @@ public class PgVectorStoreObservationIT { .withUsername("postgres") .withPassword("postgres"); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class) + .withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE", + + // JdbcTemplate configuration + String.format("app.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), + postgresContainer.getMappedPort(5432), "postgres"), + "app.datasource.username=postgres", "app.datasource.password=postgres", + "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -90,26 +100,16 @@ public class PgVectorStoreObservationIT { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class) - .withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE", - - // JdbcTemplate configuration - String.format("app.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), - postgresContainer.getMappedPort(5432), "postgres"), - "app.datasource.username=postgres", "app.datasource.password=postgres", - "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); - @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java index 488dbd3f7..993f8856e 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.Collections; + import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; import org.mockito.ArgumentCaptor; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.jdbc.core.BatchPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcTemplate; + import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -29,13 +37,6 @@ import static org.mockito.Mockito.only; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import java.util.Collections; - -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.jdbc.core.BatchPreparedStatementSetter; -import org.springframework.jdbc.core.JdbcTemplate; - /** * @author Muthukumaran Navaneethakrishnan * @author Soby Chacko diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java index 5c151753d..abde63cfe 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,12 +16,6 @@ package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - import java.util.List; import java.util.Map; @@ -32,6 +26,10 @@ import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; import org.mockito.Mockito; import org.postgresql.ds.PGSimpleDataSource; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.VectorStoreChatMemoryAdvisor; import org.springframework.ai.chat.messages.AssistantMessage; @@ -43,9 +41,12 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.jdbc.core.JdbcTemplate; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; /** * @author Fabian Krüger @@ -55,14 +56,68 @@ import org.testcontainers.junit.jupiter.Testcontainers; @Testcontainers class PgVectorStoreWithChatMemoryAdvisorIT { - float[] embed = { 0.003961659F, -0.0073295482F, 0.02663665F }; - @Container @SuppressWarnings("resource") static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>(PgVectorImage.DEFAULT_IMAGE) .withUsername("postgres") .withPassword("postgres"); + float[] embed = { 0.003961659F, -0.0073295482F, 0.02663665F }; + + private static @NotNull ChatModel chatModelAlwaysReturnsTheSameReply() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Prompt.class); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage(""" + Why don't scientists trust atoms? + Because they make up everything! + """)))); + when(chatModel.call(argumentCaptor.capture())).thenReturn(chatResponse); + return chatModel; + } + + private static void initStore(PgVectorStore store) throws Exception { + store.afterPropertiesSet(); + // fill the store + store.add(List.of(new Document("Tell me a good joke", Map.of("conversationId", "default")), + new Document("Tell me a bad joke", Map.of("conversationId", "default", "messageType", "USER")))); + } + + private static PgVectorStore createPgVectorStoreUsingTestcontainer(EmbeddingModel embeddingModel) throws Exception { + JdbcTemplate jdbcTemplate = createJdbcTemplateWithConnectionToTestcontainer(); + PgVectorStore vectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withDimensions(3) // match + // embeddings + .withInitializeSchema(true) + .build(); + initStore(vectorStore); + return vectorStore; + } + + private static @NotNull JdbcTemplate createJdbcTemplateWithConnectionToTestcontainer() { + PGSimpleDataSource ds = new PGSimpleDataSource(); + ds.setUrl("jdbc:postgresql://localhost:" + postgresContainer.getMappedPort(5432) + "/postgres"); + ds.setUser(postgresContainer.getUsername()); + ds.setPassword(postgresContainer.getPassword()); + return new JdbcTemplate(ds); + } + + private static void verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(ChatModel chatModel) { + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + verify(chatModel).call(promptCaptor.capture()); + assertThat(promptCaptor.getValue().getInstructions().get(0)).isInstanceOf(SystemMessage.class); + assertThat(promptCaptor.getValue().getInstructions().get(0).getContent()).isEqualTo(""" + + + Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers. + + --------------------- + LONG_TERM_MEMORY: + Tell me a good joke + Tell me a bad joke + --------------------- + + """); + } + /** * Test that chats with {@link VectorStoreChatMemoryAdvisor} get advised with similar * messages from the (gp)vector store. @@ -88,42 +143,6 @@ class PgVectorStoreWithChatMemoryAdvisorIT { verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(chatModel); } - private static @NotNull ChatModel chatModelAlwaysReturnsTheSameReply() { - ChatModel chatModel = mock(ChatModel.class); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Prompt.class); - ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage(""" - Why don't scientists trust atoms? - Because they make up everything! - """)))); - when(chatModel.call(argumentCaptor.capture())).thenReturn(chatResponse); - return chatModel; - } - - private static void initStore(PgVectorStore store) throws Exception { - store.afterPropertiesSet(); - // fill the store - store.add(List.of(new Document("Tell me a good joke", Map.of("conversationId", "default")), - new Document("Tell me a bad joke", Map.of("conversationId", "default", "messageType", "USER")))); - } - - private static PgVectorStore createPgVectorStoreUsingTestcontainer(EmbeddingModel embeddingModel) throws Exception { - JdbcTemplate jdbcTemplate = createJdbcTemplateWithConnectionToTestcontainer(); - PgVectorStore vectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withDimensions(3) // match - // embeddings - .withInitializeSchema(true) - .build(); - initStore(vectorStore); - return vectorStore; - } - - private static @NotNull JdbcTemplate createJdbcTemplateWithConnectionToTestcontainer() { - PGSimpleDataSource ds = new PGSimpleDataSource(); - ds.setUrl("jdbc:postgresql://localhost:" + postgresContainer.getMappedPort(5432) + "/postgres"); - ds.setUser(postgresContainer.getUsername()); - ds.setPassword(postgresContainer.getPassword()); - return new JdbcTemplate(ds); - } - @SuppressWarnings("unchecked") private @NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed() { EmbeddingModel embeddingModel = mock(EmbeddingModel.class); @@ -131,29 +150,11 @@ class PgVectorStoreWithChatMemoryAdvisorIT { Mockito.doAnswer(invocationOnMock -> { Object[] arguments = invocationOnMock.getArguments(); List documents = (List) arguments[0]; - documents.forEach(d -> d.setEmbedding(embed)); - return List.of(embed, embed); + documents.forEach(d -> d.setEmbedding(this.embed)); + return List.of(this.embed, this.embed); }).when(embeddingModel).embed(ArgumentMatchers.any(), any(), any()); - when(embeddingModel.embed(any(String.class))).thenReturn(embed); + when(embeddingModel.embed(any(String.class))).thenReturn(this.embed); return embeddingModel; } - private static void verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(ChatModel chatModel) { - ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); - verify(chatModel).call(promptCaptor.capture()); - assertThat(promptCaptor.getValue().getInstructions().get(0)).isInstanceOf(SystemMessage.class); - assertThat(promptCaptor.getValue().getInstructions().get(0).getContent()).isEqualTo(""" - - - Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers. - - --------------------- - LONG_TERM_MEMORY: - Tell me a good joke - Tell me a bad joke - --------------------- - - """); - } - -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-pinecone-store/pom.xml b/vector-stores/spring-ai-pinecone-store/pom.xml index b1c883453..87b2722c5 100644 --- a/vector-stores/spring-ai-pinecone-store/pom.xml +++ b/vector-stores/spring-ai-pinecone-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java index 59a7f7ff8..f4243653b 100644 --- a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java +++ b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -21,6 +21,22 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import io.micrometer.observation.ObservationRegistry; +import io.pinecone.PineconeClient; +import io.pinecone.PineconeClientConfig; +import io.pinecone.PineconeConnection; +import io.pinecone.PineconeConnectionConfig; +import io.pinecone.proto.DeleteRequest; +import io.pinecone.proto.QueryRequest; +import io.pinecone.proto.QueryResponse; +import io.pinecone.proto.UpsertRequest; +import io.pinecone.proto.Vector; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -36,23 +52,6 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationConv import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.protobuf.Struct; -import com.google.protobuf.Value; -import com.google.protobuf.util.JsonFormat; - -import io.micrometer.observation.ObservationRegistry; -import io.pinecone.PineconeClient; -import io.pinecone.PineconeClientConfig; -import io.pinecone.PineconeConnection; -import io.pinecone.PineconeConnectionConfig; -import io.pinecone.proto.DeleteRequest; -import io.pinecone.proto.QueryRequest; -import io.pinecone.proto.QueryResponse; -import io.pinecone.proto.UpsertRequest; -import io.pinecone.proto.Vector; - /** * A VectorStore implementation backed by Pinecone, a cloud-based vector database. This * store supports creating, updating, deleting, and similarity searching of documents in a @@ -86,180 +85,6 @@ public class PineconeVectorStore extends AbstractObservationVectorStore { private final BatchingStrategy batchingStrategy; - /** - * Configuration class for the PineconeVectorStore. - */ - public static final class PineconeVectorStoreConfig { - - // The free tier (gcp-starter) doesn't support Namespaces. - // Leave the namespace empty (e.g. "") for the free tier. - private final String namespace; - - private final String contentFieldName; - - private final String distanceMetadataFieldName; - - private final PineconeConnectionConfig connectionConfig; - - private final PineconeClientConfig clientConfig; - - // private final int defaultSimilarityTopK; - - /** - * Constructor using the builder. - * @param builder The configuration builder. - */ - /** - * Constructor using the builder. - * @param builder The configuration builder. - */ - public PineconeVectorStoreConfig(Builder builder) { - this.namespace = builder.namespace; - this.contentFieldName = builder.contentFieldName; - this.distanceMetadataFieldName = builder.distanceMetadataFieldName; - - // this.defaultSimilarityTopK = builder.defaultSimilarityTopK; - this.connectionConfig = new PineconeConnectionConfig().withIndexName(builder.indexName); - this.clientConfig = new PineconeClientConfig().withApiKey(builder.apiKey) - .withEnvironment(builder.environment) - .withProjectName(builder.projectId) - .withApiKey(builder.apiKey) - .withServerSideTimeoutSec((int) builder.serverSideTimeout.toSeconds()); - } - - /** - * Start building a new configuration. - * @return The entry point for creating a new configuration. - */ - public static Builder builder() { - return new Builder(); - } - - /** - * {@return the default config} - */ - public static PineconeVectorStoreConfig defaultConfig() { - return builder().build(); - } - - public static class Builder { - - private String apiKey; - - private String projectId; - - private String environment; - - private String indexName; - - // The free-tier (gcp-starter) doesn't support Namespaces! - private String namespace = ""; - - private String contentFieldName = CONTENT_FIELD_NAME; - - private String distanceMetadataFieldName = DISTANCE_METADATA_FIELD_NAME; - - /** - * Optional server-side timeout in seconds for all operations. Default: 20 - * seconds. - */ - private Duration serverSideTimeout = Duration.ofSeconds(20); - - private Builder() { - } - - /** - * Pinecone api key. - * @param apiKey key to use. - * @return this builder. - */ - public Builder withApiKey(String apiKey) { - this.apiKey = apiKey; - return this; - } - - /** - * Pinecone project id. - * @param projectId Project id to use. - * @return this builder. - */ - public Builder withProjectId(String projectId) { - this.projectId = projectId; - return this; - } - - /** - * Pinecone environment name. - * @param environment Environment name (e.g. gcp-starter). - * @return this builder. - */ - public Builder withEnvironment(String environment) { - this.environment = environment; - return this; - } - - /** - * Pinecone index name. - * @param indexName Pinecone index name to use. - * @return this builder. - */ - public Builder withIndexName(String indexName) { - this.indexName = indexName; - return this; - } - - /** - * Pinecone Namespace. The free-tier (gcp-starter) doesn't support Namespaces. - * For free-tier leave the namespace empty. - * @param namespace Pinecone namespace to use. - * @return this builder. - */ - public Builder withNamespace(String namespace) { - this.namespace = namespace; - return this; - } - - /** - * Content field name. - * @param contentFieldName content field name to use. - * @return this builder. - */ - public Builder withContentFieldName(String contentFieldName) { - this.contentFieldName = contentFieldName; - return this; - } - - /** - * Distance metadata field name. - * @param distanceMetadataFieldName distance metadata field name to use. - * @return this builder. - */ - public Builder withDistanceMetadataFieldName(String distanceMetadataFieldName) { - this.distanceMetadataFieldName = distanceMetadataFieldName; - return this; - } - - /** - * Pinecone server side timeout. - * @param serverSideTimeout server timeout to use. - * @return this builder. - */ - public Builder withServerSideTimeout(Duration serverSideTimeout) { - this.serverSideTimeout = serverSideTimeout; - return this; - } - - /** - * {@return the immutable configuration} - */ - public PineconeVectorStoreConfig build() { - return new PineconeVectorStoreConfig(this); - } - - } - - } - /** * Constructs a new PineconeVectorStore. * @param config The configuration for the store. @@ -442,6 +267,7 @@ public class PineconeVectorStore extends AbstractObservationVectorStore { try { String json = JsonFormat.printer().print(metadataStruct); Map metadata = this.objectMapper.readValue(json, new TypeReference>() { + }); metadata.remove(this.pineconeContentFieldName); return metadata; @@ -461,4 +287,178 @@ public class PineconeVectorStore extends AbstractObservationVectorStore { .withFieldName(this.pineconeContentFieldName); } + /** + * Configuration class for the PineconeVectorStore. + */ + public static final class PineconeVectorStoreConfig { + + // The free tier (gcp-starter) doesn't support Namespaces. + // Leave the namespace empty (e.g. "") for the free tier. + private final String namespace; + + private final String contentFieldName; + + private final String distanceMetadataFieldName; + + private final PineconeConnectionConfig connectionConfig; + + private final PineconeClientConfig clientConfig; + + // private final int defaultSimilarityTopK; + + /** + * Constructor using the builder. + * @param builder The configuration builder. + */ + /** + * Constructor using the builder. + * @param builder The configuration builder. + */ + public PineconeVectorStoreConfig(Builder builder) { + this.namespace = builder.namespace; + this.contentFieldName = builder.contentFieldName; + this.distanceMetadataFieldName = builder.distanceMetadataFieldName; + + // this.defaultSimilarityTopK = builder.defaultSimilarityTopK; + this.connectionConfig = new PineconeConnectionConfig().withIndexName(builder.indexName); + this.clientConfig = new PineconeClientConfig().withApiKey(builder.apiKey) + .withEnvironment(builder.environment) + .withProjectName(builder.projectId) + .withApiKey(builder.apiKey) + .withServerSideTimeoutSec((int) builder.serverSideTimeout.toSeconds()); + } + + /** + * Start building a new configuration. + * @return The entry point for creating a new configuration. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * {@return the default config} + */ + public static PineconeVectorStoreConfig defaultConfig() { + return builder().build(); + } + + public static class Builder { + + private String apiKey; + + private String projectId; + + private String environment; + + private String indexName; + + // The free-tier (gcp-starter) doesn't support Namespaces! + private String namespace = ""; + + private String contentFieldName = CONTENT_FIELD_NAME; + + private String distanceMetadataFieldName = DISTANCE_METADATA_FIELD_NAME; + + /** + * Optional server-side timeout in seconds for all operations. Default: 20 + * seconds. + */ + private Duration serverSideTimeout = Duration.ofSeconds(20); + + private Builder() { + } + + /** + * Pinecone api key. + * @param apiKey key to use. + * @return this builder. + */ + public Builder withApiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + /** + * Pinecone project id. + * @param projectId Project id to use. + * @return this builder. + */ + public Builder withProjectId(String projectId) { + this.projectId = projectId; + return this; + } + + /** + * Pinecone environment name. + * @param environment Environment name (e.g. gcp-starter). + * @return this builder. + */ + public Builder withEnvironment(String environment) { + this.environment = environment; + return this; + } + + /** + * Pinecone index name. + * @param indexName Pinecone index name to use. + * @return this builder. + */ + public Builder withIndexName(String indexName) { + this.indexName = indexName; + return this; + } + + /** + * Pinecone Namespace. The free-tier (gcp-starter) doesn't support Namespaces. + * For free-tier leave the namespace empty. + * @param namespace Pinecone namespace to use. + * @return this builder. + */ + public Builder withNamespace(String namespace) { + this.namespace = namespace; + return this; + } + + /** + * Content field name. + * @param contentFieldName content field name to use. + * @return this builder. + */ + public Builder withContentFieldName(String contentFieldName) { + this.contentFieldName = contentFieldName; + return this; + } + + /** + * Distance metadata field name. + * @param distanceMetadataFieldName distance metadata field name to use. + * @return this builder. + */ + public Builder withDistanceMetadataFieldName(String distanceMetadataFieldName) { + this.distanceMetadataFieldName = distanceMetadataFieldName; + return this; + } + + /** + * Pinecone server side timeout. + * @param serverSideTimeout server timeout to use. + * @return this builder. + */ + public Builder withServerSideTimeout(Duration serverSideTimeout) { + this.serverSideTimeout = serverSideTimeout; + return this; + } + + /** + * {@return the immutable configuration} + */ + public PineconeVectorStoreConfig build() { + return new PineconeVectorStoreConfig(this); + } + + } + + } + } diff --git a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStoreHints.java b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStoreHints.java index c9b63fbb4..20c0dd75d 100644 --- a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStoreHints.java +++ b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStoreHints.java @@ -1,11 +1,27 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore; +import java.util.Set; + import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHintsRegistrar; -import java.util.Set; - /** * Registration of AOT hints for Pinecone's vector store. * diff --git a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java index a0f8a257c..917abaf47 100644 --- a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java +++ b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.io.IOException; @@ -62,6 +63,9 @@ public class PineconeVectorStoreIT { private static final String CUSTOM_CONTENT_FIELD_NAME = "article"; + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + List documents = List.of( new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), @@ -77,9 +81,6 @@ public class PineconeVectorStoreIT { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(TestApplication.class); - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -90,11 +91,11 @@ public class PineconeVectorStoreIT { @Test public void addAndSearchTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await().until(() -> { return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); @@ -104,14 +105,14 @@ public class PineconeVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); Awaitility.await().until(() -> { return vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)); @@ -125,7 +126,7 @@ public class PineconeVectorStoreIT { // Pinecone metadata filtering syntax: // https://docs.pinecone.io/docs/metadata-filtering - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -177,7 +178,7 @@ public class PineconeVectorStoreIT { public void documentUpdateTest() { // Note ,using OpenAI to calculate embeddings - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -234,11 +235,11 @@ public class PineconeVectorStoreIT { @Test public void searchThresholdTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Awaitility.await().until(() -> { return vectorStore @@ -259,13 +260,13 @@ public class PineconeVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); Awaitility.await().until(() -> { return vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)); }, hasSize(0)); @@ -301,4 +302,4 @@ public class PineconeVectorStoreIT { } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java index 8fc6b41f7..b8a84de02 100644 --- a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.Matchers.hasSize; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,11 +22,15 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.awaitility.Awaitility; import org.awaitility.Duration; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -45,9 +47,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; /** * @author Christian Tzolov @@ -67,6 +68,9 @@ public class PineconeVectorStoreObservationIT { private static final String CUSTOM_CONTENT_FIELD_NAME = "article"; + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -82,9 +86,6 @@ public class PineconeVectorStoreObservationIT { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class); - @BeforeAll public static void beforeAll() { Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); @@ -95,13 +96,13 @@ public class PineconeVectorStoreObservationIT { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() @@ -165,7 +166,7 @@ public class PineconeVectorStoreObservationIT { .hasBeenStopped(); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); Awaitility.await().until(() -> { return vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)); diff --git a/vector-stores/spring-ai-qdrant-store/pom.xml b/vector-stores/spring-ai-qdrant-store/pom.xml index f88c4f84e..2d9c598e5 100644 --- a/vector-stores/spring-ai-qdrant-store/pom.xml +++ b/vector-stores/spring-ai-qdrant-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantFilterExpressionConverter.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantFilterExpressionConverter.java index ec1d1fbdd..1870a37aa 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantFilterExpressionConverter.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.qdrant; +import java.util.ArrayList; +import java.util.List; + +import io.qdrant.client.grpc.Points.Condition; +import io.qdrant.client.grpc.Points.Filter; +import io.qdrant.client.grpc.Points.Range; + +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; +import org.springframework.ai.vectorstore.filter.Filter.Group; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.Filter.Operand; +import org.springframework.ai.vectorstore.filter.Filter.Value; + import static io.qdrant.client.ConditionFactory.filter; import static io.qdrant.client.ConditionFactory.match; import static io.qdrant.client.ConditionFactory.matchExceptKeywords; @@ -24,20 +39,6 @@ import static io.qdrant.client.ConditionFactory.matchKeywords; import static io.qdrant.client.ConditionFactory.matchValues; import static io.qdrant.client.ConditionFactory.range; -import java.util.ArrayList; -import java.util.List; - -import org.springframework.ai.vectorstore.filter.Filter.Expression; -import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; -import org.springframework.ai.vectorstore.filter.Filter.Group; -import org.springframework.ai.vectorstore.filter.Filter.Key; -import org.springframework.ai.vectorstore.filter.Filter.Operand; -import org.springframework.ai.vectorstore.filter.Filter.Value; - -import io.qdrant.client.grpc.Points.Condition; -import io.qdrant.client.grpc.Points.Filter; -import io.qdrant.client.grpc.Points.Range; - /** * @author Anush Shetty * @since 0.8.1 diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantObjectFactory.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantObjectFactory.java index 8f86cb872..00ad1e518 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantObjectFactory.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantObjectFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.qdrant; import java.util.Map; diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantValueFactory.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantValueFactory.java index 87c1067dd..13862abc0 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantValueFactory.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantValueFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.qdrant; import java.lang.reflect.Array; diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java index 158caf903..8fadc237a 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,32 +16,12 @@ package org.springframework.ai.vectorstore.qdrant; -import static io.qdrant.client.PointIdFactory.id; -import static io.qdrant.client.ValueFactory.value; -import static io.qdrant.client.VectorsFactory.vectors; -import static io.qdrant.client.WithPayloadSelectorFactory.enable; - import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.ExecutionException; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.BatchingStrategy; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; -import org.springframework.ai.embedding.TokenCountBatchingStrategy; -import org.springframework.ai.model.EmbeddingUtils; -import org.springframework.ai.observation.conventions.VectorStoreProvider; -import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; -import org.springframework.ai.vectorstore.SearchRequest; -import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; -import org.springframework.beans.factory.InitializingBean; -import org.springframework.util.Assert; - import io.micrometer.observation.ObservationRegistry; import io.qdrant.client.QdrantClient; import io.qdrant.client.grpc.Collections.Distance; @@ -54,6 +34,25 @@ import io.qdrant.client.grpc.Points.ScoredPoint; import io.qdrant.client.grpc.Points.SearchPoints; import io.qdrant.client.grpc.Points.UpdateStatus; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; +import org.springframework.ai.model.EmbeddingUtils; +import org.springframework.ai.observation.conventions.VectorStoreProvider; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.util.Assert; + +import static io.qdrant.client.PointIdFactory.id; +import static io.qdrant.client.ValueFactory.value; +import static io.qdrant.client.VectorsFactory.vectors; +import static io.qdrant.client.WithPayloadSelectorFactory.enable; + /** * Qdrant vectorStore implementation. This store supports creating, updating, deleting, * and similarity searching of documents in a Qdrant collection. @@ -67,12 +66,12 @@ import io.qdrant.client.grpc.Points.UpdateStatus; */ public class QdrantVectorStore extends AbstractObservationVectorStore implements InitializingBean { + public static final String DEFAULT_COLLECTION_NAME = "vector_store"; + private static final String CONTENT_FIELD_NAME = "doc_content"; private static final String DISTANCE_FIELD_NAME = "distance"; - public static final String DEFAULT_COLLECTION_NAME = "vector_store"; - private final EmbeddingModel embeddingModel; private final QdrantClient qdrantClient; @@ -85,68 +84,6 @@ public class QdrantVectorStore extends AbstractObservationVectorStore implements private final BatchingStrategy batchingStrategy; - /** - * Configuration class for the QdrantVectorStore. - * - * @deprecated since 1.0.0 in favor of {@link QdrantVectorStore}. - */ - @Deprecated(since = "1.0.0", forRemoval = true) - public static final class QdrantVectorStoreConfig { - - private final String collectionName; - - /* - * Constructor using the builder. - * - * @param builder The configuration builder. - */ - - private QdrantVectorStoreConfig(Builder builder) { - this.collectionName = builder.collectionName; - } - - /** - * Start building a new configuration. - * @return The entry point for creating a new configuration. - */ - public static Builder builder() { - return new Builder(); - } - - /** - * {@return the default config} - */ - public static QdrantVectorStoreConfig defaultConfig() { - return builder().build(); - } - - public static class Builder { - - private String collectionName; - - private Builder() { - } - - /** - * @param collectionName REQUIRED. The name of the collection. - */ - public Builder withCollectionName(String collectionName) { - this.collectionName = collectionName; - return this; - } - - /** - * {@return the immutable configuration} - */ - public QdrantVectorStoreConfig build() { - Assert.notNull(collectionName, "collectionName cannot be null"); - return new QdrantVectorStoreConfig(this); - } - - } - - } - /** * Constructs a new QdrantVectorStore. * @param config The configuration for the store. @@ -319,8 +256,9 @@ public class QdrantVectorStore extends AbstractObservationVectorStore implements @Override public void afterPropertiesSet() throws Exception { - if (!this.initializeSchema) + if (!this.initializeSchema) { return; + } // Create the collection if it does not exist. if (!isCollectionExists()) { @@ -350,4 +288,66 @@ public class QdrantVectorStore extends AbstractObservationVectorStore implements } -} \ No newline at end of file + /** + * Configuration class for the QdrantVectorStore. + * + * @deprecated since 1.0.0 in favor of {@link QdrantVectorStore}. + */ + @Deprecated(since = "1.0.0", forRemoval = true) + public static final class QdrantVectorStoreConfig { + + private final String collectionName; + + /* + * Constructor using the builder. + * + * @param builder The configuration builder. + */ + + private QdrantVectorStoreConfig(Builder builder) { + this.collectionName = builder.collectionName; + } + + /** + * Start building a new configuration. + * @return The entry point for creating a new configuration. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * {@return the default config} + */ + public static QdrantVectorStoreConfig defaultConfig() { + return builder().build(); + } + + public static class Builder { + + private String collectionName; + + private Builder() { + } + + /** + * @param collectionName REQUIRED. The name of the collection. + */ + public Builder withCollectionName(String collectionName) { + this.collectionName = collectionName; + return this; + } + + /** + * {@return the immutable configuration} + */ + public QdrantVectorStoreConfig build() { + Assert.notNull(this.collectionName, "collectionName cannot be null"); + return new QdrantVectorStoreConfig(this); + } + + } + + } + +} diff --git a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantImage.java b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantImage.java index 4a6592ffa..2045be309 100644 --- a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantImage.java +++ b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.qdrant; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java index 7cd00a4ef..b5e30f945 100644 --- a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java +++ b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.qdrant; import java.util.Collections; @@ -28,14 +29,14 @@ import io.qdrant.client.grpc.Collections.VectorParams; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.mistralai.MistralAiEmbeddingModel; -import org.springframework.ai.mistralai.api.MistralAiApi; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.qdrant.QdrantContainer; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.mistralai.MistralAiEmbeddingModel; +import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.boot.SpringBootConfiguration; @@ -62,6 +63,10 @@ public class QdrantVectorStoreIT { @Container static QdrantContainer qdrantContainer = new QdrantContainer(QdrantImage.DEFAULT_IMAGE); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class) + .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")); + List documents = List.of( new Document("Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!!", Collections.singletonMap("meta1", "meta1")), @@ -70,10 +75,6 @@ public class QdrantVectorStoreIT { "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression", Collections.singletonMap("meta2", "meta2"))); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(TestApplication.class) - .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")); - @BeforeAll static void setup() throws InterruptedException, ExecutionException { @@ -91,23 +92,23 @@ public class QdrantVectorStoreIT { @Test public void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List results = vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); List results2 = vectorStore.similaritySearch(SearchRequest.query("Great").withTopK(1)); assertThat(results2).hasSize(0); @@ -117,7 +118,7 @@ public class QdrantVectorStoreIT { @Test public void addAndSearchWithFilters() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -166,7 +167,7 @@ public class QdrantVectorStoreIT { @Test public void documentUpdateTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -206,11 +207,11 @@ public class QdrantVectorStoreIT { @Test public void searchThresholdTest() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); var request = SearchRequest.query("Great").withTopK(5); List fullResult = vectorStore.similaritySearch(request.withSimilarityThresholdAll()); @@ -225,14 +226,14 @@ public class QdrantVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); }); } diff --git a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java index 5d181cf03..42031c3c5 100644 --- a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore.qdrant; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore.qdrant; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -23,9 +22,20 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; +import io.qdrant.client.QdrantClient; +import io.qdrant.client.QdrantGrpcClient; +import io.qdrant.client.grpc.Collections.Distance; +import io.qdrant.client.grpc.Collections.VectorParams; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.qdrant.QdrantContainer; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -43,17 +53,8 @@ import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.qdrant.QdrantContainer; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import io.qdrant.client.QdrantClient; -import io.qdrant.client.QdrantGrpcClient; -import io.qdrant.client.grpc.Collections.Distance; -import io.qdrant.client.grpc.Collections.VectorParams; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -70,6 +71,9 @@ public class QdrantVectorStoreObservationIT { @Container static QdrantContainer qdrantContainer = new QdrantContainer(QdrantImage.DEFAULT_IMAGE); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -85,9 +89,6 @@ public class QdrantVectorStoreObservationIT { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class); - @BeforeAll static void setup() throws InterruptedException, ExecutionException { @@ -106,13 +107,13 @@ public class QdrantVectorStoreObservationIT { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-redis-store/pom.xml b/vector-stores/spring-ai-redis-store/pom.xml index 6d8a0ca9a..cf0c0e4d2 100644 --- a/vector-stores/spring-ai-redis-store/pom.xml +++ b/vector-stores/spring-ai-redis-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverter.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverter.java index 95198ff3b..86638c071 100644 --- a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverter.java +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.text.MessageFormat; @@ -177,6 +178,7 @@ public class RedisFilterExpressionConverter extends AbstractFilterExpressionConv } static record Numeric(NumericBoundary lower, NumericBoundary upper) { + } static record NumericBoundary(Object value, boolean exclusive) { @@ -197,11 +199,11 @@ public class RedisFilterExpressionConverter extends AbstractFilterExpressionConv if (this == POSITIVE_INFINITY) { return INFINITY; } - return String.format(formatString(), value); + return String.format(formatString(), this.value); } private String formatString() { - if (exclusive) { + if (this.exclusive) { return EXCLUSIVE_FORMAT; } return INCLUSIVE_FORMAT; diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java index d6c7c6808..3a850f1dc 100644 --- a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -27,24 +27,9 @@ import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; +import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.BatchingStrategy; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; -import org.springframework.ai.embedding.TokenCountBatchingStrategy; -import org.springframework.ai.observation.conventions.VectorStoreProvider; -import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; -import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; -import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; -import org.springframework.beans.factory.InitializingBean; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; - -import io.micrometer.observation.ObservationRegistry; import redis.clients.jedis.JedisPooled; import redis.clients.jedis.Pipeline; import redis.clients.jedis.json.Path2; @@ -61,6 +46,21 @@ import redis.clients.jedis.search.schemafields.TextField; import redis.clients.jedis.search.schemafields.VectorField; import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; +import org.springframework.ai.observation.conventions.VectorStoreProvider; +import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; +import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + /** * The RedisVectorStore is for managing and querying vector data in a Redis database. It * offers functionalities like adding, deleting, and performing similarity searches on @@ -87,165 +87,6 @@ import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm; */ public class RedisVectorStore extends AbstractObservationVectorStore implements InitializingBean { - public enum Algorithm { - - FLAT, HSNW - - } - - public record MetadataField(String name, FieldType fieldType) { - - public static MetadataField text(String name) { - return new MetadataField(name, FieldType.TEXT); - } - - public static MetadataField numeric(String name) { - return new MetadataField(name, FieldType.NUMERIC); - } - - public static MetadataField tag(String name) { - return new MetadataField(name, FieldType.TAG); - } - - } - - /** - * Configuration for the Redis vector store. - */ - public static final class RedisVectorStoreConfig { - - private final String indexName; - - private final String prefix; - - private final String contentFieldName; - - private final String embeddingFieldName; - - private final Algorithm vectorAlgorithm; - - private final List metadataFields; - - private RedisVectorStoreConfig() { - this(builder()); - } - - private RedisVectorStoreConfig(Builder builder) { - this.indexName = builder.indexName; - this.prefix = builder.prefix; - this.contentFieldName = builder.contentFieldName; - this.embeddingFieldName = builder.embeddingFieldName; - this.vectorAlgorithm = builder.vectorAlgorithm; - this.metadataFields = builder.metadataFields; - } - - /** - * Start building a new configuration. - * @return The entry point for creating a new configuration. - */ - public static Builder builder() { - - return new Builder(); - } - - /** - * {@return the default config} - */ - public static RedisVectorStoreConfig defaultConfig() { - - return builder().build(); - } - - public static class Builder { - - private String indexName = DEFAULT_INDEX_NAME; - - private String prefix = DEFAULT_PREFIX; - - private String contentFieldName = DEFAULT_CONTENT_FIELD_NAME; - - private String embeddingFieldName = DEFAULT_EMBEDDING_FIELD_NAME; - - private Algorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM; - - private List metadataFields = new ArrayList<>(); - - private Builder() { - } - - /** - * Configures the Redis index name to use. - * @param name the index name to use - * @return this builder - */ - public Builder withIndexName(String name) { - this.indexName = name; - return this; - } - - /** - * Configures the Redis key prefix to use (default: "embedding:"). - * @param prefix the prefix to use - * @return this builder - */ - public Builder withPrefix(String prefix) { - this.prefix = prefix; - return this; - } - - /** - * Configures the Redis content field name to use. - * @param name the content field name to use - * @return this builder - */ - public Builder withContentFieldName(String name) { - this.contentFieldName = name; - return this; - } - - /** - * Configures the Redis embedding field name to use. - * @param name the embedding field name to use - * @return this builder - */ - public Builder withEmbeddingFieldName(String name) { - this.embeddingFieldName = name; - return this; - } - - /** - * Configures the Redis vector algorithmto use. - * @param algorithm the vector algorithm to use - * @return this builder - */ - public Builder withVectorAlgorithm(Algorithm algorithm) { - this.vectorAlgorithm = algorithm; - return this; - } - - public Builder withMetadataFields(MetadataField... fields) { - return withMetadataFields(Arrays.asList(fields)); - } - - public Builder withMetadataFields(List fields) { - this.metadataFields = fields; - return this; - } - - /** - * {@return the immutable configuration} - */ - public RedisVectorStoreConfig build() { - - return new RedisVectorStoreConfig(this); - } - - } - - } - - private final boolean initializeSchema; - public static final String DEFAULT_INDEX_NAME = "spring-ai-index"; public static final String DEFAULT_CONTENT_FIELD_NAME = "content"; @@ -256,6 +97,8 @@ public class RedisVectorStore extends AbstractObservationVectorStore implements public static final Algorithm DEFAULT_VECTOR_ALGORITHM = Algorithm.HSNW; + public static final String DISTANCE_FIELD_NAME = "vector_score"; + private static final String QUERY_FORMAT = "%s=>[KNN %s @%s $%s AS %s]"; private static final Path2 JSON_SET_PATH = Path2.of("$"); @@ -272,20 +115,20 @@ public class RedisVectorStore extends AbstractObservationVectorStore implements private static final String EMBEDDING_PARAM_NAME = "BLOB"; - public static final String DISTANCE_FIELD_NAME = "vector_score"; - private static final String DEFAULT_DISTANCE_METRIC = "COSINE"; + private final boolean initializeSchema; + private final JedisPooled jedis; private final EmbeddingModel embeddingModel; private final RedisVectorStoreConfig config; - private FilterExpressionConverter filterExpressionConverter; - private final BatchingStrategy batchingStrategy; + private FilterExpressionConverter filterExpressionConverter; + public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis, boolean initializeSchema) { @@ -475,7 +318,7 @@ public class RedisVectorStore extends AbstractObservationVectorStore implements } private VectorAlgorithm vectorAlgorithm() { - if (config.vectorAlgorithm == Algorithm.HSNW) { + if (this.config.vectorAlgorithm == Algorithm.HSNW) { return VectorAlgorithm.HNSW; } return VectorAlgorithm.FLAT; @@ -497,4 +340,161 @@ public class RedisVectorStore extends AbstractObservationVectorStore implements } -} \ No newline at end of file + public enum Algorithm { + + FLAT, HSNW + + } + + public record MetadataField(String name, FieldType fieldType) { + + public static MetadataField text(String name) { + return new MetadataField(name, FieldType.TEXT); + } + + public static MetadataField numeric(String name) { + return new MetadataField(name, FieldType.NUMERIC); + } + + public static MetadataField tag(String name) { + return new MetadataField(name, FieldType.TAG); + } + + } + + /** + * Configuration for the Redis vector store. + */ + public static final class RedisVectorStoreConfig { + + private final String indexName; + + private final String prefix; + + private final String contentFieldName; + + private final String embeddingFieldName; + + private final Algorithm vectorAlgorithm; + + private final List metadataFields; + + private RedisVectorStoreConfig() { + this(builder()); + } + + private RedisVectorStoreConfig(Builder builder) { + this.indexName = builder.indexName; + this.prefix = builder.prefix; + this.contentFieldName = builder.contentFieldName; + this.embeddingFieldName = builder.embeddingFieldName; + this.vectorAlgorithm = builder.vectorAlgorithm; + this.metadataFields = builder.metadataFields; + } + + /** + * Start building a new configuration. + * @return The entry point for creating a new configuration. + */ + public static Builder builder() { + + return new Builder(); + } + + /** + * {@return the default config} + */ + public static RedisVectorStoreConfig defaultConfig() { + + return builder().build(); + } + + public static class Builder { + + private String indexName = DEFAULT_INDEX_NAME; + + private String prefix = DEFAULT_PREFIX; + + private String contentFieldName = DEFAULT_CONTENT_FIELD_NAME; + + private String embeddingFieldName = DEFAULT_EMBEDDING_FIELD_NAME; + + private Algorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM; + + private List metadataFields = new ArrayList<>(); + + private Builder() { + } + + /** + * Configures the Redis index name to use. + * @param name the index name to use + * @return this builder + */ + public Builder withIndexName(String name) { + this.indexName = name; + return this; + } + + /** + * Configures the Redis key prefix to use (default: "embedding:"). + * @param prefix the prefix to use + * @return this builder + */ + public Builder withPrefix(String prefix) { + this.prefix = prefix; + return this; + } + + /** + * Configures the Redis content field name to use. + * @param name the content field name to use + * @return this builder + */ + public Builder withContentFieldName(String name) { + this.contentFieldName = name; + return this; + } + + /** + * Configures the Redis embedding field name to use. + * @param name the embedding field name to use + * @return this builder + */ + public Builder withEmbeddingFieldName(String name) { + this.embeddingFieldName = name; + return this; + } + + /** + * Configures the Redis vector algorithmto use. + * @param algorithm the vector algorithm to use + * @return this builder + */ + public Builder withVectorAlgorithm(Algorithm algorithm) { + this.vectorAlgorithm = algorithm; + return this; + } + + public Builder withMetadataFields(MetadataField... fields) { + return withMetadataFields(Arrays.asList(fields)); + } + + public Builder withMetadataFields(List fields) { + this.metadataFields = fields; + return this; + } + + /** + * {@return the immutable configuration} + */ + public RedisVectorStoreConfig build() { + + return new RedisVectorStoreConfig(this); + } + + } + + } + +} diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverterTests.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverterTests.java index f1e96534c..c2c0901ba 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; +import java.util.Arrays; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.vectorstore.RedisVectorStore.MetadataField; +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.Group; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.Filter.Value; + import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag; import static org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.numeric; +import static org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AND; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.GTE; @@ -27,16 +39,6 @@ import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NE import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; -import java.util.Arrays; -import java.util.List; - -import org.junit.jupiter.api.Test; -import org.springframework.ai.vectorstore.RedisVectorStore.MetadataField; -import org.springframework.ai.vectorstore.filter.Filter.Expression; -import org.springframework.ai.vectorstore.filter.Filter.Group; -import org.springframework.ai.vectorstore.filter.Filter.Key; -import org.springframework.ai.vectorstore.filter.Filter.Value; - /** * @author Julien Ruaux */ diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java index 44497602d..124d76d58 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,8 +23,13 @@ import java.util.List; import java.util.Map; import java.util.UUID; +import com.redis.testcontainers.RedisStackContainer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import redis.clients.jedis.JedisPooled; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -40,11 +44,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.data.redis.connection.jedis.JedisConnectionFactory; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import com.redis.testcontainers.RedisStackContainer; -import redis.clients.jedis.JedisPooled; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Julien Ruaux @@ -93,24 +94,24 @@ class RedisVectorStoreIT { @Test void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).isEmpty(); @@ -120,7 +121,7 @@ class RedisVectorStoreIT { @Test void searchWithFilters() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", @@ -174,7 +175,7 @@ class RedisVectorStoreIT { @Test void documentUpdate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -215,11 +216,11 @@ class RedisVectorStoreIT { @Test void searchWithThreshold() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List fullResult = vectorStore .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); @@ -237,7 +238,7 @@ class RedisVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME); diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreObservationIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreObservationIT.java index 64b30ce05..2d2ed538c 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import com.redis.testcontainers.RedisStackContainer; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import redis.clients.jedis.JedisPooled; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -44,15 +51,8 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.data.redis.connection.jedis.JedisConnectionFactory; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import com.redis.testcontainers.RedisStackContainer; - -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import redis.clients.jedis.JedisPooled; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -65,6 +65,11 @@ public class RedisVectorStoreObservationIT { static RedisStackContainer redisContainer = new RedisStackContainer( RedisStackContainer.DEFAULT_IMAGE_NAME.withTag(RedisStackContainer.DEFAULT_TAG)); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(RedisAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .withPropertyValues("spring.data.redis.url=" + redisContainer.getRedisURI()); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -80,11 +85,6 @@ public class RedisVectorStoreObservationIT { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(RedisAutoConfiguration.class)) - .withUserConfiguration(Config.class) - .withPropertyValues("spring.data.redis.url=" + redisContainer.getRedisURI()); - @BeforeEach void cleanDatabase() { this.contextRunner.run(context -> context.getBean(RedisVectorStore.class).getJedis().flushAll()); @@ -93,13 +93,13 @@ public class RedisVectorStoreObservationIT { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-typesense-store/pom.xml b/vector-stores/spring-ai-typesense-store/pom.xml index b43f7711d..14087b76b 100644 --- a/vector-stores/spring-ai-typesense-store/pom.xml +++ b/vector-stores/spring-ai-typesense-store/pom.xml @@ -1,4 +1,20 @@ + + diff --git a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseFilterExpressionConverter.java b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseFilterExpressionConverter.java index 0f19340d8..5706dd154 100644 --- a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseFilterExpressionConverter.java +++ b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseFilterExpressionConverter.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.vectorstore; import org.springframework.ai.vectorstore.filter.Filter; @@ -40,7 +56,7 @@ public class TypesenseFilterExpressionConverter extends AbstractFilterExpression return " "; // in typesense "IN" operator looks like -> country: [USA, UK] case NIN: return " != "; // in typesense "NIN" operator looks like -> country: - // !=[USA, UK] + // !=[USA, UK] default: throw new RuntimeException("Not supported expression type:" + exp.type()); } diff --git a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java index c27161a0f..1a7d0b36d 100644 --- a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java +++ b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -23,8 +23,20 @@ import java.util.Optional; import java.util.stream.IntStream; import java.util.stream.Stream; +import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.typesense.api.Client; +import org.typesense.api.FieldTypes; +import org.typesense.model.CollectionResponse; +import org.typesense.model.CollectionSchema; +import org.typesense.model.DeleteDocumentsParameters; +import org.typesense.model.Field; +import org.typesense.model.ImportDocumentsParameters; +import org.typesense.model.MultiSearchCollectionParameters; +import org.typesense.model.MultiSearchResult; +import org.typesense.model.MultiSearchSearchesParameter; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; @@ -38,18 +50,6 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationCont import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; -import org.typesense.api.Client; -import org.typesense.api.FieldTypes; -import org.typesense.model.CollectionResponse; -import org.typesense.model.CollectionSchema; -import org.typesense.model.DeleteDocumentsParameters; -import org.typesense.model.Field; -import org.typesense.model.ImportDocumentsParameters; -import org.typesense.model.MultiSearchCollectionParameters; -import org.typesense.model.MultiSearchResult; -import org.typesense.model.MultiSearchSearchesParameter; - -import io.micrometer.observation.ObservationRegistry; /** * @author Pablo Sanchidrian Herrera @@ -58,8 +58,6 @@ import io.micrometer.observation.ObservationRegistry; */ public class TypesenseVectorStore extends AbstractObservationVectorStore implements InitializingBean { - private static final Logger logger = LoggerFactory.getLogger(TypesenseVectorStore.class); - /** * The name of the field that contains the document ID. It is mandatory to set "id" as * the field name because that is the name that typesense is going to look for. @@ -78,88 +76,20 @@ public class TypesenseVectorStore extends AbstractObservationVectorStore impleme public static final int INVALID_EMBEDDING_DIMENSION = -1; + private static final Logger logger = LoggerFactory.getLogger(TypesenseVectorStore.class); + + public final FilterExpressionConverter filterExpressionConverter = new TypesenseFilterExpressionConverter(); + private final Client client; private final EmbeddingModel embeddingModel; private final TypesenseVectorStoreConfig config; - public final FilterExpressionConverter filterExpressionConverter = new TypesenseFilterExpressionConverter(); - private final boolean initializeSchema; private final BatchingStrategy batchingStrategy; - public static class TypesenseVectorStoreConfig { - - private final String collectionName; - - private final int embeddingDimension; - - public TypesenseVectorStoreConfig(String collectionName, int embeddingDimension) { - this.collectionName = collectionName; - this.embeddingDimension = embeddingDimension; - } - - /** - * {@return the default config} - */ - public static TypesenseVectorStoreConfig defaultConfig() { - return builder().build(); - } - - private TypesenseVectorStoreConfig(Builder builder) { - this.collectionName = builder.collectionName; - this.embeddingDimension = builder.embeddingDimension; - } - - /** - * Start building a new configuration. - * @return The entry point for creating a new configuration. - */ - public static Builder builder() { - - return new Builder(); - } - - public static class Builder { - - private String collectionName; - - private int embeddingDimension; - - /** - * Set the collection name. - * @param collectionName The collection name. - * @return The builder. - */ - public Builder withCollectionName(String collectionName) { - this.collectionName = collectionName; - return this; - } - - /** - * Set the embedding dimension. - * @param embeddingDimension The embedding dimension. - * @return The builder. - */ - public Builder withEmbeddingDimension(int embeddingDimension) { - this.embeddingDimension = embeddingDimension; - return this; - } - - /** - * Build the configuration. - * @return The configuration. - */ - public TypesenseVectorStoreConfig build() { - return new TypesenseVectorStoreConfig(this); - } - - } - - } - public TypesenseVectorStore(Client client, EmbeddingModel embeddingModel) { this(client, embeddingModel, TypesenseVectorStoreConfig.defaultConfig(), false); } @@ -396,4 +326,74 @@ public class TypesenseVectorStore extends AbstractObservationVectorStore impleme .withSimilarityMetric(VectorStoreSimilarityMetric.COSINE.value()); } + public static class TypesenseVectorStoreConfig { + + private final String collectionName; + + private final int embeddingDimension; + + public TypesenseVectorStoreConfig(String collectionName, int embeddingDimension) { + this.collectionName = collectionName; + this.embeddingDimension = embeddingDimension; + } + + private TypesenseVectorStoreConfig(Builder builder) { + this.collectionName = builder.collectionName; + this.embeddingDimension = builder.embeddingDimension; + } + + /** + * {@return the default config} + */ + public static TypesenseVectorStoreConfig defaultConfig() { + return builder().build(); + } + + /** + * Start building a new configuration. + * @return The entry point for creating a new configuration. + */ + public static Builder builder() { + + return new Builder(); + } + + public static class Builder { + + private String collectionName; + + private int embeddingDimension; + + /** + * Set the collection name. + * @param collectionName The collection name. + * @return The builder. + */ + public Builder withCollectionName(String collectionName) { + this.collectionName = collectionName; + return this; + } + + /** + * Set the embedding dimension. + * @param embeddingDimension The embedding dimension. + * @return The builder. + */ + public Builder withEmbeddingDimension(int embeddingDimension) { + this.embeddingDimension = embeddingDimension; + return this; + } + + /** + * Build the configuration. + * @return The configuration. + */ + public TypesenseVectorStoreConfig build() { + return new TypesenseVectorStoreConfig(this); + } + + } + + } + } diff --git a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseImage.java b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseImage.java index ac27de2a8..ea06769bc 100644 --- a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseImage.java +++ b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java index 044b0114a..23cd9f03b 100644 --- a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java +++ b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,7 +16,23 @@ package org.springframework.ai.vectorstore; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; + import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.typesense.api.Client; +import org.typesense.api.Configuration; +import org.typesense.resources.Node; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -27,21 +43,6 @@ import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.typesense.api.Client; -import org.typesense.api.Configuration; -import org.typesense.resources.Node; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.UUID; import static org.assertj.core.api.Assertions.assertThat; @@ -79,7 +80,7 @@ public class TypesenseVectorStoreIT { @Test void documentUpdate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", Collections.singletonMap("meta1", "meta1")); @@ -127,10 +128,10 @@ public class TypesenseVectorStoreIT { @Test void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); Map info = ((TypesenseVectorStore) vectorStore).getCollectionInfo(); @@ -146,7 +147,7 @@ public class TypesenseVectorStoreIT { @Test void searchWithFilters() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", @@ -201,11 +202,11 @@ public class TypesenseVectorStoreIT { @Test void searchWithThreshold() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.add(documents); + vectorStore.add(this.documents); List fullResult = vectorStore .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); @@ -221,7 +222,7 @@ public class TypesenseVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); diff --git a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreObservationIT.java b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreObservationIT.java index c5fdc8810..fefde3d3e 100644 --- a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,7 +23,17 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.typesense.api.Client; +import org.typesense.api.Configuration; +import org.typesense.resources.Node; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -41,16 +50,8 @@ import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.typesense.api.Client; -import org.typesense.api.Configuration; -import org.typesense.resources.Node; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -87,13 +88,13 @@ public class TypesenseVectorStoreObservationIT { @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation() diff --git a/vector-stores/spring-ai-weaviate-store/pom.xml b/vector-stores/spring-ai-weaviate-store/pom.xml index c9359ca08..c1ea5ff95 100644 --- a/vector-stores/spring-ai-weaviate-store/pom.xml +++ b/vector-stores/spring-ai-weaviate-store/pom.xml @@ -1,4 +1,20 @@ + + 4.0.0 diff --git a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java index 5dcb1caeb..08eb2f29f 100644 --- a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java +++ b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.Date; @@ -37,11 +38,11 @@ import org.springframework.util.Assert; */ public class WeaviateFilterExpressionConverter extends AbstractFilterExpressionConverter { - private boolean mapIntegerToNumberValue = true; - // https://weaviate.io/developers/weaviate/api/graphql/filters#special-cases private static final List SYSTEM_IDENTIFIERS = List.of("id", "_creationTimeUnix", "_lastUpdateTimeUnix"); + private boolean mapIntegerToNumberValue = true; + private List allowedIdentifierNames; public WeaviateFilterExpressionConverter(List allowedIdentifierNames) { @@ -189,4 +190,4 @@ public class WeaviateFilterExpressionConverter extends AbstractFilterExpressionC context); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java index 0c381b4e7..58a831eb1 100644 --- a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java +++ b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -24,25 +24,8 @@ import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.BatchingStrategy; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; -import org.springframework.ai.embedding.TokenCountBatchingStrategy; -import org.springframework.ai.model.EmbeddingUtils; -import org.springframework.ai.observation.conventions.VectorStoreProvider; -import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.ConsistentLevel; -import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField; -import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; - import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; - import io.micrometer.observation.ObservationRegistry; import io.weaviate.client.WeaviateClient; import io.weaviate.client.base.Result; @@ -61,6 +44,23 @@ import io.weaviate.client.v1.graphql.query.builder.GetBuilder.GetBuilderBuilder; import io.weaviate.client.v1.graphql.query.fields.Field; import io.weaviate.client.v1.graphql.query.fields.Fields; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; +import org.springframework.ai.model.EmbeddingUtils; +import org.springframework.ai.observation.conventions.VectorStoreProvider; +import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.ConsistentLevel; +import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField; +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + /** * A VectorStore implementation backed by Weaviate vector database. * @@ -130,164 +130,6 @@ public class WeaviateVectorStore extends AbstractObservationVectorStore { */ private final ObjectMapper objectMapper = new ObjectMapper(); - /** - * Configuration class for the WeaviateVectorStore. - */ - public static final class WeaviateVectorStoreConfig { - - public record MetadataField(String name, Type type) { - public enum Type { - - TEXT, NUMBER, BOOLEAN - - } - - public static MetadataField text(String name) { - return new MetadataField(name, Type.TEXT); - } - - public static MetadataField number(String name) { - return new MetadataField(name, Type.NUMBER); - } - - public static MetadataField bool(String name) { - return new MetadataField(name, Type.BOOLEAN); - } - - } - - /** - * https://weaviate.io/developers/weaviate/concepts/replication-architecture/consistency#tunable-consistency-strategies - */ - public enum ConsistentLevel { - - /** - * Write must receive an acknowledgement from at least one replica node. This - * is the fastest (most available), but least consistent option. - */ - ONE, - - /** - * Write must receive an acknowledgement from at least QUORUM replica nodes. - * QUORUM is calculated as n / 2 + 1, where n is the number of replicas. - */ - QUORUM, - - /** - * Write must receive an acknowledgement from all replica nodes. This is the - * most consistent, but 'slowest'. - */ - ALL - - } - - private final String weaviateObjectClass; - - private final ConsistentLevel consistencyLevel; - - /** - * Known metadata fields to add as a fields to the Weaviate schema. You can add - * arbitrary metadata with your documents but only the metadata fields listed here - * can be used in the expression filters. - */ - private final List filterMetadataFields; - - private final Map headers; - - /** - * Constructor using the builder. - * @param builder The configuration builder. - */ - public WeaviateVectorStoreConfig(Builder builder) { - this.weaviateObjectClass = builder.objectClass; - this.consistencyLevel = builder.consistencyLevel; - this.filterMetadataFields = builder.filterMetadataFields; - this.headers = builder.headers; - } - - /** - * Start building a new configuration. - * @return The entry point for creating a new configuration. - */ - public static Builder builder() { - return new Builder(); - } - - /** - * {@return the default config} - */ - public static WeaviateVectorStoreConfig defaultConfig() { - return builder().build(); - } - - public static class Builder { - - private String objectClass = "SpringAiWeaviate"; - - private ConsistentLevel consistencyLevel = WeaviateVectorStoreConfig.ConsistentLevel.ONE; - - private List filterMetadataFields = List.of(); - - private Map headers = Map.of(); - - private Builder() { - } - - /** - * Weaviate known, filterable metadata fields. - * @param filterMetadataFields known metadata fields to use. - * @return this builder. - */ - public Builder withFilterableMetadataFields(List filterMetadataFields) { - Assert.notNull(filterMetadataFields, "The filterMetadataFields can not be null."); - this.filterMetadataFields = filterMetadataFields; - return this; - } - - /** - * Weaviate config headers. - * @param headers config headers to use. - * @return this builder. - */ - public Builder withHeaders(Map headers) { - Assert.notNull(headers, "The headers can not be null."); - this.headers = headers; - return this; - } - - /** - * Weaviate objectClass. - * @param objectClass objectClass to use. - * @return this builder. - */ - public Builder withObjectClass(String objectClass) { - Assert.hasText(objectClass, "The objectClass can not be empty."); - this.objectClass = objectClass; - return this; - } - - /** - * Weaviate consistencyLevel. - * @param consistencyLevel consistencyLevel to use. - * @return this builder. - */ - public Builder withConsistencyLevel(ConsistentLevel consistencyLevel) { - Assert.notNull(consistencyLevel, "The consistencyLevel can not be null."); - this.consistencyLevel = consistencyLevel; - return this; - } - - /** - * {@return the immutable configuration} - */ - public WeaviateVectorStoreConfig build() { - return new WeaviateVectorStoreConfig(this); - } - - } - - } - /** * Constructs a new WeaviateVectorStore. * @param vectorStoreConfig The configuration for the store. @@ -462,7 +304,7 @@ public class WeaviateVectorStore extends AbstractObservationVectorStore { .build()) .limit(request.getTopK()) .withWhereFilter(WhereArgument.builder().build()) // adds an empty 'where:{}' - // placeholder. + // placeholder. .fields(Fields.builder().fields(this.weaviateSimilaritySearchFields).build()); String graphQLQuery = queryBuilder.build().buildQuery(); @@ -554,4 +396,163 @@ public class WeaviateVectorStore extends AbstractObservationVectorStore { .withCollectionName(this.weaviateObjectClass); } -} \ No newline at end of file + /** + * Configuration class for the WeaviateVectorStore. + */ + public static final class WeaviateVectorStoreConfig { + + private final String weaviateObjectClass; + + private final ConsistentLevel consistencyLevel; + + /** + * Known metadata fields to add as a fields to the Weaviate schema. You can add + * arbitrary metadata with your documents but only the metadata fields listed here + * can be used in the expression filters. + */ + private final List filterMetadataFields; + + private final Map headers; + + /** + * Constructor using the builder. + * @param builder The configuration builder. + */ + public WeaviateVectorStoreConfig(Builder builder) { + this.weaviateObjectClass = builder.objectClass; + this.consistencyLevel = builder.consistencyLevel; + this.filterMetadataFields = builder.filterMetadataFields; + this.headers = builder.headers; + } + + /** + * Start building a new configuration. + * @return The entry point for creating a new configuration. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * {@return the default config} + */ + public static WeaviateVectorStoreConfig defaultConfig() { + return builder().build(); + } + + /** + * https://weaviate.io/developers/weaviate/concepts/replication-architecture/consistency#tunable-consistency-strategies + */ + public enum ConsistentLevel { + + /** + * Write must receive an acknowledgement from at least one replica node. This + * is the fastest (most available), but least consistent option. + */ + ONE, + + /** + * Write must receive an acknowledgement from at least QUORUM replica nodes. + * QUORUM is calculated as n / 2 + 1, where n is the number of replicas. + */ + QUORUM, + + /** + * Write must receive an acknowledgement from all replica nodes. This is the + * most consistent, but 'slowest'. + */ + ALL + + } + + public record MetadataField(String name, Type type) { + + public static MetadataField text(String name) { + return new MetadataField(name, Type.TEXT); + } + + public static MetadataField number(String name) { + return new MetadataField(name, Type.NUMBER); + } + + public static MetadataField bool(String name) { + return new MetadataField(name, Type.BOOLEAN); + } + + public enum Type { + + TEXT, NUMBER, BOOLEAN + + } + + } + + public static class Builder { + + private String objectClass = "SpringAiWeaviate"; + + private ConsistentLevel consistencyLevel = WeaviateVectorStoreConfig.ConsistentLevel.ONE; + + private List filterMetadataFields = List.of(); + + private Map headers = Map.of(); + + private Builder() { + } + + /** + * Weaviate known, filterable metadata fields. + * @param filterMetadataFields known metadata fields to use. + * @return this builder. + */ + public Builder withFilterableMetadataFields(List filterMetadataFields) { + Assert.notNull(filterMetadataFields, "The filterMetadataFields can not be null."); + this.filterMetadataFields = filterMetadataFields; + return this; + } + + /** + * Weaviate config headers. + * @param headers config headers to use. + * @return this builder. + */ + public Builder withHeaders(Map headers) { + Assert.notNull(headers, "The headers can not be null."); + this.headers = headers; + return this; + } + + /** + * Weaviate objectClass. + * @param objectClass objectClass to use. + * @return this builder. + */ + public Builder withObjectClass(String objectClass) { + Assert.hasText(objectClass, "The objectClass can not be empty."); + this.objectClass = objectClass; + return this; + } + + /** + * Weaviate consistencyLevel. + * @param consistencyLevel consistencyLevel to use. + * @return this builder. + */ + public Builder withConsistencyLevel(ConsistentLevel consistencyLevel) { + Assert.notNull(consistencyLevel, "The consistencyLevel can not be null."); + this.consistencyLevel = consistencyLevel; + return this; + } + + /** + * {@return the immutable configuration} + */ + public WeaviateVectorStoreConfig build() { + return new WeaviateVectorStoreConfig(this); + } + + } + + } + +} diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java index fc004b3e1..53275f567 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.List; diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateImage.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateImage.java index 81e78bc5c..3dbfcdd93 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateImage.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateImage.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.testcontainers.utility.DockerImageName; diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java index 6f57d7a5d..b474cdaed 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,7 +23,14 @@ import java.util.List; import java.util.Map; import java.util.UUID; +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateClient; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.weaviate.WeaviateContainer; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -35,13 +41,8 @@ import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.containers.wait.strategy.Wait; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.weaviate.WeaviateContainer; -import io.weaviate.client.Config; -import io.weaviate.client.WeaviateClient; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -78,32 +79,32 @@ public class WeaviateVectorStoreIT { } private void resetCollection(VectorStore vectorStore) { - vectorStore.delete(documents.stream().map(Document::getId).toList()); + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); } @Test public void addAndSearch() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); resetCollection(vectorStore); - vectorStore.add(documents); + vectorStore.add(this.documents); List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); // Remove all documents from the store - vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); assertThat(results).hasSize(0); @@ -113,7 +114,7 @@ public class WeaviateVectorStoreIT { @Test public void searchWithFilters() throws InterruptedException { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", @@ -167,7 +168,7 @@ public class WeaviateVectorStoreIT { @Test public void documentUpdate() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -210,13 +211,13 @@ public class WeaviateVectorStoreIT { @Test public void searchWithThreshold() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); resetCollection(vectorStore); - vectorStore.add(documents); + vectorStore.add(this.documents); List fullResult = vectorStore .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); @@ -234,7 +235,7 @@ public class WeaviateVectorStoreIT { assertThat(results).hasSize(1); Document resultDoc = results.get(0); - assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); @@ -267,4 +268,4 @@ public class WeaviateVectorStoreIT { } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreObservationIT.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreObservationIT.java index 35f9c4d59..17b54c188 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreObservationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.vectorstore; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.ai.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; +import io.weaviate.client.WeaviateClient; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.weaviate.WeaviateContainer; + import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; @@ -38,15 +46,8 @@ import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import org.testcontainers.containers.wait.strategy.Wait; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.weaviate.WeaviateContainer; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistry; -import io.micrometer.observation.tck.TestObservationRegistryAssert; -import io.weaviate.client.WeaviateClient; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov @@ -59,6 +60,9 @@ public class WeaviateVectorStoreObservationIT { static WeaviateContainer weaviateContainer = new WeaviateContainer(WeaviateImage.DEFAULT_IMAGE) .waitingFor(Wait.forHttp("/v1/.well-known/ready").forPort(8080)); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class); + List documents = List.of( new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document(getText("classpath:/test/data/time.shelter.txt")), @@ -74,19 +78,16 @@ public class WeaviateVectorStoreObservationIT { } } - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(Config.class); - @Test void observationVectorStoreAddAndQueryOperations() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); - vectorStore.add(documents); + vectorStore.add(this.documents); TestObservationRegistryAssert.assertThat(observationRegistry) .doesNotHaveAnyRemainingCurrentObservation()