Move native hints into individual modules

* Use aot.factories approach for registration
* Add tests
* final tweaks- Thanks Josh!
This commit is contained in:
Mark Pollack
2024-02-13 11:11:24 -05:00
parent ea7dce3833
commit 596f2b06c0
34 changed files with 609 additions and 207 deletions

View File

@@ -0,0 +1,43 @@
package org.springframework.ai.reader.pdf.aot;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import java.io.IOException;
import java.util.Set;
/**
* The PdfReaderRuntimeHints class is responsible for registering runtime hints for PDFBox
* resources.
*
* @author Josh Long
* @author Christian Tzolov
* @author Mark Pollack
*/
public class PdfReaderRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
try {
var resolver = new PathMatchingResourcePatternResolver();
var patterns = Set.of("/org/apache/pdfbox/resources/glyphlist/zapfdingbats.txt",
"/org/apache/pdfbox/resources/glyphlist/glyphlist.txt", "/org/apache/fontbox/cmap/**",
"/org/apache/pdfbox/resources/afm/**", "/org/apache/pdfbox/resources/glyphlist/**",
"/org/apache/pdfbox/resources/icc/**", "/org/apache/pdfbox/resources/text/**",
"/org/apache/pdfbox/resources/ttf/**", "/org/apache/pdfbox/resources/version.properties");
for (var pattern : patterns)
for (var resourceMatch : resolver.getResources(pattern))
hints.resources().registerResource(resourceMatch);
}
catch (IOException e) {
throw new RuntimeException(e);
}
}
}

View File

@@ -0,0 +1,2 @@
org.springframework.aot.hint.RuntimeHintsRegistrar=\
org.springframework.ai.reader.pdf.aot.PdfReaderRuntimeHints

View File

@@ -0,0 +1,36 @@
package org.springframework.ai.reader.pdf.aot;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.TypeReference;
import java.util.Set;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection;
import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.resource;
class PdfReaderRuntimeHintsTests {
@Test
void registerHints() {
RuntimeHints runtimeHints = new RuntimeHints();
PdfReaderRuntimeHints pdfReaderRuntimeHints = new PdfReaderRuntimeHints();
pdfReaderRuntimeHints.registerHints(runtimeHints, null);
Assertions.assertThat(runtimeHints)
.matches(resource().forResource("/org/apache/pdfbox/resources/glyphlist/zapfdingbats.txt"));
Assertions.assertThat(runtimeHints)
.matches(resource().forResource("/org/apache/pdfbox/resources/glyphlist/glyphlist.txt"));
// Assertions.assertThat(runtimeHints).matches(resource().forResource("/org/apache/pdfbox/resources/afm/**"));
// Assertions.assertThat(runtimeHints).matches(resource().forResource("/org/apache/pdfbox/resources/glyphlist/**"));
// Assertions.assertThat(runtimeHints).matches(resource().forResource("/org/apache/pdfbox/resources/icc/**"));
// Assertions.assertThat(runtimeHints).matches(resource().forResource("/org/apache/pdfbox/resources/text/**"));
// Assertions.assertThat(runtimeHints).matches(resource().forResource("/org/apache/pdfbox/resources/ttf/**"));
Assertions.assertThat(runtimeHints)
.matches(resource().forResource("/org/apache/pdfbox/resources/version.properties"));
}
}

View File

@@ -0,0 +1,45 @@
package org.springframework.ai.bedrock.aot;
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi;
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
/**
* The BedrockRuntimeHints class is responsible for registering runtime hints for Bedrock
* AI API classes.
*
* @author Josh Long
* @author Christian Tzolov
* @author Mark Pollack
*/
public class BedrockRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClassesInPackage(Ai21Jurassic2ChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClassesInPackage(CohereChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClassesInPackage(CohereEmbeddingBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClassesInPackage(Llama2ChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClassesInPackage(TitanChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClassesInPackage(TitanEmbeddingBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClassesInPackage(AnthropicChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
}
}

View File

@@ -0,0 +1,2 @@
org.springframework.aot.hint.RuntimeHintsRegistrar=\
org.springframework.ai.bedrock.aot.BedrockRuntimeHints

View File

@@ -0,0 +1,43 @@
package org.springframework.ai.bedrock.aot;
import org.junit.jupiter.api.Test;
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi;
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.TypeReference;
import java.util.List;
import java.util.Set;
import java.util.Arrays;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection;
class BedrockRuntimeHintsTests {
@Test
void registerHints() {
RuntimeHints runtimeHints = new RuntimeHints();
BedrockRuntimeHints bedrockRuntimeHints = new BedrockRuntimeHints();
bedrockRuntimeHints.registerHints(runtimeHints, null);
List<Class> classList = Arrays.asList(Ai21Jurassic2ChatBedrockApi.class, CohereChatBedrockApi.class,
CohereEmbeddingBedrockApi.class, Llama2ChatBedrockApi.class, TitanChatBedrockApi.class,
TitanEmbeddingBedrockApi.class, AnthropicChatBedrockApi.class);
for (Class aClass : classList) {
Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(aClass);
for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass));
}
}
}
}

View File

@@ -0,0 +1,30 @@
package org.springframework.ai.ollama.aot;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
/**
* The OllamaRuntimeHints class is responsible for registering runtime hints for Ollama AI
* API classes.
*
* @author Josh Long
* @author Christian Tzolov
* @author Mark Pollack
*/
public class OllamaRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClassesInPackage(OllamaApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClassesInPackage(OllamaOptions.class))
hints.reflection().registerType(tr, mcs);
}
}

View File

@@ -0,0 +1,2 @@
org.springframework.aot.hint.RuntimeHintsRegistrar=\
org.springframework.ai.vertex.aot.OllamaRuntimeHints

View File

@@ -0,0 +1,34 @@
package org.springframework.ai.ollama.aot;
import org.junit.jupiter.api.Test;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.TypeReference;
import java.util.Set;
import static org.assertj.core.api.AssertionsForClassTypes.*;
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.*;
class OllamaRuntimeHintsTests {
@Test
void registerHints() {
RuntimeHints runtimeHints = new RuntimeHints();
OllamaRuntimeHints ollamaRuntimeHints = new OllamaRuntimeHints();
ollamaRuntimeHints.registerHints(runtimeHints, null);
Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(OllamaApi.class);
for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass));
}
jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(OllamaOptions.class);
for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass));
}
}
}

View File

@@ -0,0 +1,27 @@
package org.springframework.ai.openai.aot;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
/**
* The OpenAiRuntimeHints class is responsible for registering runtime hints for OpenAI
* API classes.
*
* @author Josh Long
* @author Christian Tzolov
* @author Mark Pollack
*/
public class OpenAiRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClassesInPackage(OpenAiApi.class))
hints.reflection().registerType(tr, mcs);
}
}

View File

@@ -0,0 +1,2 @@
org.springframework.aot.hint.RuntimeHintsRegistrar=\
org.springframework.ai.openai.aot.OpenAiRuntimeHints

View File

@@ -0,0 +1,28 @@
package org.springframework.ai.openai.aot;
import org.junit.jupiter.api.Test;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.TypeReference;
import java.util.Set;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection;
class OpenAiRuntimeHintsTests {
@Test
void registerHints() {
RuntimeHints runtimeHints = new RuntimeHints();
OpenAiRuntimeHints openAiRuntimeHints = new OpenAiRuntimeHints();
openAiRuntimeHints.registerHints(runtimeHints, null);
Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(OpenAiApi.class);
for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass));
}
}
}

View File

@@ -0,0 +1,27 @@
package org.springframework.ai.vertex.aot;
import org.springframework.ai.vertex.api.VertexAiApi;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
/**
* The VertexRuntimeHints class is responsible for registering runtime hints for Vertex AI
* API classes.
*
* @author Josh Long
* @author Christian Tzolov
* @author Mark Pollack
*/
public class VertexRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClassesInPackage(VertexAiApi.class))
hints.reflection().registerType(tr, mcs);
}
}

View File

@@ -0,0 +1,2 @@
org.springframework.aot.hint.RuntimeHintsRegistrar=\
org.springframework.ai.vertex.aot.VertexRuntimeHints

View File

@@ -0,0 +1,27 @@
package org.springframework.ai.vertex.aot;
import org.junit.jupiter.api.Test;
import org.springframework.ai.vertex.api.VertexAiApi;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.TypeReference;
import java.util.Set;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection;
class VertexRuntimeHintsTests {
@Test
void registerHints() {
RuntimeHints runtimeHints = new RuntimeHints();
VertexRuntimeHints vertexRuntimeHints = new VertexRuntimeHints();
vertexRuntimeHints.registerHints(runtimeHints, null);
Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(VertexAiApi.class);
for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass));
}
}
}

View File

@@ -50,6 +50,14 @@
<artifactId>reactor-core</artifactId>
</dependency>
<!-- Spring Framework -->
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-context</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-messaging</artifactId>

View File

@@ -0,0 +1,114 @@
package org.springframework.ai.aot;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.aot.hint.TypeReference;
import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider;
import org.springframework.core.type.filter.AnnotationTypeFilter;
import java.lang.reflect.Executable;
import java.util.*;
import java.util.stream.Collectors;
/**
* Native runtime hints. See other modules for their respective native runtime hints.
*
* @author Josh Long
* @author Christian Tzolov
* @author Mark Pollack
*/
public class AiRuntimeHints {
private static final Logger log = LoggerFactory.getLogger(AiRuntimeHints.class);
public static Set<TypeReference> findJsonAnnotatedClassesInPackage(String packageName) {
var classPathScanningCandidateComponentProvider = new ClassPathScanningCandidateComponentProvider(false);
var annotationTypeFilter = new AnnotationTypeFilter(JsonInclude.class);
classPathScanningCandidateComponentProvider.addIncludeFilter((metadataReader, metadataReaderFactory) -> {
try {
var clazz = Class.forName(metadataReader.getClassMetadata().getClassName());
return annotationTypeFilter.match(metadataReader, metadataReaderFactory)
|| !discoverJacksonAnnotatedTypesFromRootType(clazz).isEmpty();
}
catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
});
return classPathScanningCandidateComponentProvider//
.findCandidateComponents(packageName)//
.stream()//
.map(bd -> TypeReference.of(Objects.requireNonNull(bd.getBeanClassName())))//
.peek(tr -> {
if (log.isDebugEnabled())
log.debug("registering [" + tr.getName() + ']');
})
.collect(Collectors.toUnmodifiableSet());
}
public static Set<TypeReference> findJsonAnnotatedClassesInPackage(Class<?> packageClass) {
return findJsonAnnotatedClassesInPackage(packageClass.getPackageName());
}
private static boolean hasJacksonAnnotations(Class<?> type) {
var hasAnnotation = false;
var annotationsToFind = Set.of(JsonProperty.class, JsonInclude.class);
for (var annotationToFind : annotationsToFind) {
if (type.isAnnotationPresent(annotationToFind)) {
hasAnnotation = true;
}
var executables = new HashSet<Executable>();
executables.addAll(Set.of(type.getMethods()));
executables.addAll(Set.of(type.getConstructors()));
executables.addAll(Set.of(type.getDeclaredConstructors()));
for (var executable : executables) {
//
if (executable.isAnnotationPresent(annotationToFind)) {
hasAnnotation = true;
}
///
for (var p : executable.getParameters()) {
if (p.isAnnotationPresent(annotationToFind)) {
hasAnnotation = true;
}
}
}
if (type.getRecordComponents() != null) {
for (var r : type.getRecordComponents()) {
if (r.isAnnotationPresent(annotationToFind)) {
hasAnnotation = true;
}
}
}
for (var f : type.getFields()) {
if (f.isAnnotationPresent(annotationToFind)) {
hasAnnotation = true;
}
}
}
return hasAnnotation;
}
private static Set<Class<?>> discoverJacksonAnnotatedTypesFromRootType(Class<?> type) {
var jsonTypes = new HashSet<Class<?>>();
var classesToInspect = new HashSet<Class<?>>();
classesToInspect.add(type);
classesToInspect.addAll(Arrays.asList(type.getNestMembers()));
for (var n : classesToInspect) {
if (hasJacksonAnnotations(n)) {
jsonTypes.add(n);
}
}
return jsonTypes;
}
}

View File

@@ -0,0 +1,14 @@
package org.springframework.ai.aot;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.core.io.ClassPathResource;
public class KnuddelsRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
hints.resources().registerResource(new ClassPathResource("/com/knuddels/jtokkit/cl100k_base.tiktoken"));
}
}

View File

@@ -0,0 +1,27 @@
package org.springframework.ai.aot;
import org.springframework.ai.chat.messages.*;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.core.io.ClassPathResource;
import java.util.Set;
public class SpringAiCoreRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
var chatTypes = Set.of(AbstractMessage.class, AssistantMessage.class, ChatMessage.class, FunctionMessage.class,
Message.class, MessageType.class, UserMessage.class, SystemMessage.class);
for (var c : chatTypes) {
hints.reflection().registerType(c);
}
for (var r : Set.of("antlr4/org/springframework/ai/vectorstore/filter/antlr4/Filters.g4",
"embedding/embedding-model-dimensions.properties"))
hints.resources().registerResource(new ClassPathResource(r));
}
}

View File

@@ -0,0 +1,3 @@
org.springframework.aot.hint.RuntimeHintsRegistrar=\
org.springframework.ai.aot.SpringAiCoreRuntimeHints,\
org.springframework.ai.aot.KnuddelsRuntimeHints

View File

@@ -0,0 +1,47 @@
package org.springframework.ai.aot;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.Test;
import org.springframework.aot.hint.TypeReference;
import org.springframework.util.Assert;
import java.util.Set;
import java.util.stream.Collectors;
import static org.assertj.core.api.Assertions.assertThat;
class AiRuntimeHintsTests {
@JsonInclude
static class TestApi {
static class FooBar {
}
record Foo(@JsonProperty("name") String name) {
}
@JsonInclude
enum Bar {
A, B
}
}
@Test
void discoverRelevantClasses() throws Exception {
var classes = AiRuntimeHints.findJsonAnnotatedClassesInPackage(TestApi.class);
var included = Set.of(TestApi.Bar.class, TestApi.Foo.class)
.stream()
.map(t -> TypeReference.of(t.getName()))
.collect(Collectors.toSet());
LogFactory.getLog(getClass()).info(classes);
Assert.state(classes.containsAll(included), "there should be all of the enumerated classes. ");
}
}

View File

@@ -0,0 +1,19 @@
package org.springframework.ai.aot;
import org.junit.jupiter.api.Test;
import org.springframework.aot.hint.RuntimeHints;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.resource;
class KnuddelsRuntimeHintsTest {
@Test
void knuddels() {
var runtimeHints = new RuntimeHints();
var knuddels = new KnuddelsRuntimeHints();
knuddels.registerHints(runtimeHints, null);
assertThat(runtimeHints).matches(resource().forResource("com/knuddels/jtokkit/cl100k_base.tiktoken"));
}
}

View File

@@ -0,0 +1,19 @@
package org.springframework.ai.aot;
import org.junit.jupiter.api.Test;
import org.springframework.aot.hint.RuntimeHints;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.resource;
class SpringAiCoreRuntimeHintsTest {
@Test
void core() {
var runtimeHints = new RuntimeHints();
var knuddels = new SpringAiCoreRuntimeHints();
knuddels.registerHints(runtimeHints, null);
assertThat(runtimeHints).matches(resource().forResource("embedding/embedding-model-dimensions.properties"));
}
}

View File

@@ -1,159 +0,0 @@
package org.springframework.ai.autoconfigure;
import com.fasterxml.jackson.annotation.JsonInclude;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi;
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.vertex.api.VertexAiApi;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.aot.hint.TypeReference;
import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.type.filter.AnnotationTypeFilter;
import java.io.IOException;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/***
* Native hints
*
* @author Josh Long
*/
public class NativeHints implements RuntimeHintsRegistrar {
static final Logger log = LoggerFactory.getLogger(NativeHints.class);
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
for (var h : Set.of(new BedrockAiHints(), new VertexAiHints(), new OpenAiHints(), new PdfReaderHints(),
new KnuddelsHints(), new OllamaHints()))
h.registerHints(hints, classLoader);
hints.resources().registerResource(new ClassPathResource("embedding/embedding-model-dimensions.properties"));
}
private static Set<TypeReference> findJsonAnnotatedClasses(Class<?> packageClass) {
var packageName = packageClass.getPackageName();
var classPathScanningCandidateComponentProvider = new ClassPathScanningCandidateComponentProvider(false);
classPathScanningCandidateComponentProvider.addIncludeFilter(new AnnotationTypeFilter(JsonInclude.class));
return classPathScanningCandidateComponentProvider.findCandidateComponents(packageName)
.stream()
.map(bd -> TypeReference.of(Objects.requireNonNull(bd.getBeanClassName())))
.peek(tr -> {
if (log.isDebugEnabled())
log.debug("registering [" + tr.getName() + ']');
})
.collect(Collectors.toUnmodifiableSet());
}
static class VertexAiHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClasses(VertexAiApi.class))
hints.reflection().registerType(tr, mcs);
}
}
static class OllamaHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClasses(OllamaApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClasses(OllamaOptions.class))
hints.reflection().registerType(tr, mcs);
}
}
static class BedrockAiHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClasses(Ai21Jurassic2ChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClasses(CohereChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClasses(CohereEmbeddingBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClasses(Llama2ChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClasses(TitanChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClasses(TitanEmbeddingBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClasses(AnthropicChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
}
}
static class OpenAiHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClasses(OpenAiApi.class))
hints.reflection().registerType(tr, mcs);
}
}
static class KnuddelsHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
hints.resources().registerResource(new ClassPathResource("/com/knuddels/jtokkit/cl100k_base.tiktoken"));
}
}
static class PdfReaderHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
try {
var resolver = new PathMatchingResourcePatternResolver();
var patterns = Set.of("/org/apache/pdfbox/resources/glyphlist/zapfdingbats.txt",
"/org/apache/pdfbox/resources/glyphlist/glyphlist.txt", "/org/apache/fontbox/cmap/**",
"/org/apache/pdfbox/resources/afm/**", "/org/apache/pdfbox/resources/glyphlist/**",
"/org/apache/pdfbox/resources/icc/**", "/org/apache/pdfbox/resources/text/**",
"/org/apache/pdfbox/resources/ttf/**", "/org/apache/pdfbox/resources/version.properties");
for (var pattern : patterns)
for (var resourceMatch : resolver.getResources(pattern))
hints.resources().registerResource(resourceMatch);
}
catch (IOException e) {
throw new RuntimeException(e);
}
}
}
}

View File

@@ -17,10 +17,6 @@
package org.springframework.ai.autoconfigure.bedrock.anthropic;
import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import org.springframework.ai.autoconfigure.NativeHints;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties;
import org.springframework.ai.bedrock.anthropic.BedrockAnthropicChatClient;
@@ -32,7 +28,7 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.ImportRuntimeHints;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
/**
* {@link AutoConfiguration Auto-configuration} for Bedrock Anthropic Chat Client.
@@ -47,7 +43,6 @@ import org.springframework.context.annotation.ImportRuntimeHints;
@EnableConfigurationProperties({ BedrockAnthropicChatProperties.class, BedrockAwsConnectionProperties.class })
@ConditionalOnProperty(prefix = BedrockAnthropicChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true")
@Import(BedrockAwsConnectionConfiguration.class)
@ImportRuntimeHints(NativeHints.class)
public class BedrockAnthropicChatAutoConfiguration {
@Bean

View File

@@ -16,14 +16,10 @@
package org.springframework.ai.autoconfigure.bedrock.cohere;
import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import org.springframework.ai.autoconfigure.NativeHints;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties;
import org.springframework.ai.bedrock.cohere.BedrockCohereChatClient;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.LogitBias;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
@@ -31,7 +27,7 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.ImportRuntimeHints;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
/**
* {@link AutoConfiguration Auto-configuration} for Bedrock Cohere Chat Client.
@@ -44,7 +40,6 @@ import org.springframework.context.annotation.ImportRuntimeHints;
@EnableConfigurationProperties({ BedrockCohereChatProperties.class, BedrockAwsConnectionProperties.class })
@ConditionalOnProperty(prefix = BedrockCohereChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true")
@Import(BedrockAwsConnectionConfiguration.class)
@ImportRuntimeHints(NativeHints.class)
public class BedrockCohereChatAutoConfiguration {
@Bean

View File

@@ -19,7 +19,6 @@ package org.springframework.ai.autoconfigure.bedrock.cohere;
import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import org.springframework.ai.autoconfigure.NativeHints;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties;
import org.springframework.ai.bedrock.cohere.BedrockCohereEmbeddingClient;
@@ -31,7 +30,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.ImportRuntimeHints;
/**
* {@link AutoConfiguration Auto-configuration} for Bedrock Cohere Embedding Client.
@@ -44,7 +42,6 @@ import org.springframework.context.annotation.ImportRuntimeHints;
@EnableConfigurationProperties({ BedrockCohereEmbeddingProperties.class, BedrockAwsConnectionProperties.class })
@ConditionalOnProperty(prefix = BedrockCohereEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true")
@Import(BedrockAwsConnectionConfiguration.class)
@ImportRuntimeHints(NativeHints.class)
public class BedrockCohereEmbeddingAutoConfiguration {
@Bean

View File

@@ -19,7 +19,6 @@ package org.springframework.ai.autoconfigure.bedrock.llama2;
import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import org.springframework.ai.autoconfigure.NativeHints;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties;
import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatClient;
@@ -31,7 +30,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.ImportRuntimeHints;
/**
* {@link AutoConfiguration Auto-configuration} for Bedrock Llama2 Chat Client.
@@ -46,7 +44,6 @@ import org.springframework.context.annotation.ImportRuntimeHints;
@EnableConfigurationProperties({ BedrockLlama2ChatProperties.class, BedrockAwsConnectionProperties.class })
@ConditionalOnProperty(prefix = BedrockLlama2ChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true")
@Import(BedrockAwsConnectionConfiguration.class)
@ImportRuntimeHints(NativeHints.class)
public class BedrockLlama2ChatAutoConfiguration {
@Bean

View File

@@ -16,9 +16,6 @@
package org.springframework.ai.autoconfigure.bedrock.titan;
import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import org.springframework.ai.autoconfigure.NativeHints;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties;
import org.springframework.ai.bedrock.titan.BedrockTitanChatClient;
@@ -30,7 +27,7 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.ImportRuntimeHints;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
/**
* {@link AutoConfiguration Auto-configuration} for Bedrock Titan Chat Client.
@@ -43,7 +40,6 @@ import org.springframework.context.annotation.ImportRuntimeHints;
@EnableConfigurationProperties({ BedrockTitanChatProperties.class, BedrockAwsConnectionProperties.class })
@ConditionalOnProperty(prefix = BedrockTitanChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true")
@Import(BedrockAwsConnectionConfiguration.class)
@ImportRuntimeHints(NativeHints.class)
public class BedrockTitanChatAutoConfiguration {
@Bean

View File

@@ -19,7 +19,6 @@ package org.springframework.ai.autoconfigure.bedrock.titan;
import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import org.springframework.ai.autoconfigure.NativeHints;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration;
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties;
import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingClient;
@@ -31,7 +30,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.ImportRuntimeHints;
/**
* {@link AutoConfiguration Auto-configuration} for Bedrock Titan Embedding Client.
@@ -44,7 +42,6 @@ import org.springframework.context.annotation.ImportRuntimeHints;
@EnableConfigurationProperties({ BedrockTitanEmbeddingProperties.class, BedrockAwsConnectionProperties.class })
@ConditionalOnProperty(prefix = BedrockTitanEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true")
@Import(BedrockAwsConnectionConfiguration.class)
@ImportRuntimeHints(NativeHints.class)
public class BedrockTitanEmbeddingAutoConfiguration {
@Bean

View File

@@ -15,7 +15,6 @@
*/
package org.springframework.ai.autoconfigure.ollama;
import org.springframework.ai.autoconfigure.NativeHints;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.ai.ollama.OllamaEmbeddingClient;
import org.springframework.ai.ollama.api.OllamaApi;
@@ -25,7 +24,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ImportRuntimeHints;
import org.springframework.web.client.RestClient;
/**
@@ -38,7 +36,6 @@ import org.springframework.web.client.RestClient;
@ConditionalOnClass(OllamaApi.class)
@EnableConfigurationProperties({ OllamaChatProperties.class, OllamaEmbeddingProperties.class,
OllamaConnectionProperties.class })
@ImportRuntimeHints(NativeHints.class)
public class OllamaAutoConfiguration {
@Bean

View File

@@ -16,12 +16,9 @@
package org.springframework.ai.autoconfigure.openai;
import java.util.List;
import org.springframework.ai.autoconfigure.NativeHints;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiEmbeddingClient;
import org.springframework.ai.openai.OpenAiImageClient;
@@ -34,17 +31,17 @@ import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfigura
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ImportRuntimeHints;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClient;
import java.util.List;
@AutoConfiguration(after = { RestClientAutoConfiguration.class })
@ConditionalOnClass(OpenAiApi.class)
@EnableConfigurationProperties({ OpenAiConnectionProperties.class, OpenAiChatProperties.class,
OpenAiEmbeddingProperties.class, OpenAiImageProperties.class })
@ImportRuntimeHints(NativeHints.class)
/**
* @author Christian Tzolov
*/

View File

@@ -15,26 +15,18 @@
*/
package org.springframework.ai.autoconfigure.stabilityai;
import org.springframework.ai.autoconfigure.NativeHints;
import org.springframework.ai.stabilityai.StabilityAiImageClient;
import org.springframework.ai.stabilityai.api.StabilityAiApi;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ImportRuntimeHints;
import org.springframework.web.client.RestClient;
/**
* @author Mark Pollack
* @since 0.8.0
*/
@AutoConfiguration(after = RestClientAutoConfiguration.class)
@ConditionalOnClass(StabilityAiApi.class)
@EnableConfigurationProperties({ StabilityAiImageProperties.class })
@ImportRuntimeHints(NativeHints.class)
public class StabilityAiImageAutoConfiguration {
@Bean

View File

@@ -16,22 +16,19 @@
package org.springframework.ai.autoconfigure.vertexai;
import org.springframework.ai.autoconfigure.NativeHints;
import org.springframework.ai.vertex.api.VertexAiApi;
import org.springframework.ai.vertex.VertexAiEmbeddingClient;
import org.springframework.ai.vertex.VertexAiChatClient;
import org.springframework.ai.vertex.VertexAiEmbeddingClient;
import org.springframework.ai.vertex.api.VertexAiApi;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ImportRuntimeHints;
import org.springframework.web.client.RestClient;
@AutoConfiguration(after = RestClientAutoConfiguration.class)
@ConditionalOnClass(VertexAiApi.class)
@ImportRuntimeHints(NativeHints.class)
@EnableConfigurationProperties({ VertexAiConnectionProperties.class, VertexAiChatProperties.class,
VertexAiEmbeddingProperties.class })
public class VertexAiAutoConfiguration {