feat(ollama): add retry template integration to OllamaChatModel

* Update tests that are supposed to fail to not use retry
* Upgrade to use Ollama 0.6.7

Signed-off-by: Alexandros Pappas <alexandros.pappas@yiluhub.com>
This commit is contained in:
Alexandros Pappas
2024-12-02 16:10:36 +01:00
committed by Mark Pollack
parent 6ef15bdd74
commit cfbefee6a2
14 changed files with 340 additions and 11 deletions

View File

@@ -34,6 +34,14 @@
</dependency>
<!-- Spring AI auto configurations -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-autoconfigure-retry</artifactId>
<version>${project.parent.version}</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-autoconfigure-model-tool</artifactId>

View File

@@ -18,6 +18,7 @@ package org.springframework.ai.model.ollama.autoconfigure;
import org.junit.jupiter.api.Test;
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -41,8 +42,9 @@ public class OllamaChatAutoConfigurationTests {
"spring.ai.ollama.chat.options.topP=0.56",
"spring.ai.ollama.chat.options.topK=123")
// @formatter:on
.withConfiguration(
AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaChatAutoConfiguration.class))
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
RestClientAutoConfiguration.class, OllamaChatAutoConfiguration.class))
.run(context -> {
var chatProperties = context.getBean(OllamaChatProperties.class);
var connectionProperties = context.getBean(OllamaConnectionProperties.class);

View File

@@ -18,6 +18,7 @@ package org.springframework.ai.model.ollama.autoconfigure;
import org.junit.jupiter.api.Test;
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -26,6 +27,7 @@ import static org.assertj.core.api.Assertions.assertThat;
/**
* @author Christian Tzolov
* @author Alexandros Pappas
* @since 0.8.0
*/
public class OllamaEmbeddingAutoConfigurationTests {
@@ -41,8 +43,9 @@ public class OllamaEmbeddingAutoConfigurationTests {
"spring.ai.ollama.embedding.options.topK=13"
// @formatter:on
)
.withConfiguration(
AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaEmbeddingAutoConfiguration.class))
// The assertions below read OllamaEmbeddingProperties, so the embedding auto-configuration must be registered.
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
RestClientAutoConfiguration.class, OllamaEmbeddingAutoConfiguration.class))
.run(context -> {
var embeddingProperties = context.getBean(OllamaEmbeddingProperties.class);
var connectionProperties = context.getBean(OllamaConnectionProperties.class);

View File

@@ -65,8 +65,13 @@ import org.springframework.ai.ollama.api.common.OllamaApiConstants;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
@@ -129,27 +134,32 @@ public class OllamaChatModel implements ChatModel {
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
private final RetryTemplate retryTemplate;
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
this(ollamaApi, defaultOptions, toolCallingManager, observationRegistry, modelManagementOptions,
new DefaultToolExecutionEligibilityPredicate());
new DefaultToolExecutionEligibilityPredicate(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions,
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate, RetryTemplate retryTemplate) {
Assert.notNull(ollamaApi, "ollamaApi must not be null");
Assert.notNull(defaultOptions, "defaultOptions must not be null");
Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");
Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");
this.chatApi = ollamaApi;
this.defaultOptions = defaultOptions;
this.toolCallingManager = toolCallingManager;
this.observationRegistry = observationRegistry;
this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
this.retryTemplate = retryTemplate;
initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
}
@@ -237,7 +247,7 @@ public class OllamaChatModel implements ChatModel {
this.observationRegistry)
.observe(() -> {
OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request);
OllamaApi.ChatResponse ollamaResponse = this.retryTemplate.execute(ctx -> this.chatApi.chat(request));
List<AssistantMessage.ToolCall> toolCalls = ollamaResponse.message().toolCalls() == null ? List.of()
: ollamaResponse.message()
@@ -540,6 +550,8 @@ public class OllamaChatModel implements ChatModel {
private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();
private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
private Builder() {
}
@@ -574,13 +586,20 @@ public class OllamaChatModel implements ChatModel {
return this;
}
/**
 * Sets the {@link RetryTemplate} that is handed to the {@link OllamaChatModel}
 * constructor when {@link #build()} is called.
 * <p>
 * The value is stored as-is; no validation happens in the builder. The model
 * constructor asserts that the template is not {@code null}.
 * @param retryTemplate the retry template to pass to the model
 * @return this builder, to allow call chaining
 */
public Builder retryTemplate(RetryTemplate retryTemplate) {
this.retryTemplate = retryTemplate;
return this;
}
public OllamaChatModel build() {
if (this.toolCallingManager != null) {
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.toolCallingManager,
this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate);
this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate,
this.retryTemplate);
}
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER,
this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate);
this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate,
this.retryTemplate);
}
}

View File

@@ -51,6 +51,7 @@ import org.springframework.web.reactive.function.client.WebClient;
* @author Christian Tzolov
* @author Thomas Vitale
* @author Jonghoon Park
* @author Alexandros Pappas
* @since 0.8.0
*/
// @formatter:off
@@ -64,6 +65,9 @@ public final class OllamaApi {
private static final Log logger = LogFactory.getLog(OllamaApi.class);
private static final String DEFAULT_BASE_URL = "http://localhost:11434";
private final RestClient restClient;
private final WebClient webClient;
@@ -77,11 +81,13 @@ public final class OllamaApi {
*/
private OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
Consumer<HttpHeaders> defaultHeaders = headers -> {
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setAccept(List.of(MediaType.APPLICATION_JSON));
};
this.restClient = restClientBuilder
.clone()
.baseUrl(baseUrl)
@@ -89,6 +95,7 @@ public final class OllamaApi {
.defaultStatusHandler(responseErrorHandler)
.build();
this.webClient = webClientBuilder
.clone()
.baseUrl(baseUrl)

View File

@@ -36,6 +36,7 @@ import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.api.tool.MockWeatherService;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
@@ -120,6 +121,7 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT {
return OllamaChatModel.builder()
.ollamaApi(ollamaApi)
.defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build())
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();
}

View File

@@ -52,8 +52,12 @@ import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.support.ToolCallbacks;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
@@ -371,6 +375,7 @@ class OllamaChatModelIT extends BaseOllamaIT {
.pullModelStrategy(PullModelStrategy.WHEN_MISSING)
.additionalModels(List.of(ADDITIONAL_MODEL))
.build())
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();
}

View File

@@ -16,6 +16,7 @@
package org.springframework.ai.ollama;
import java.time.Duration;
import java.util.List;
import org.junit.jupiter.api.Test;
@@ -27,11 +28,17 @@ import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.content.Media;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.retry.TransientAiException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.springframework.core.io.ClassPathResource;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.MimeTypeUtils;
import static org.assertj.core.api.Assertions.assertThat;
@@ -86,9 +93,23 @@ class OllamaChatModelMultimodalIT extends BaseOllamaIT {
@Bean
public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
	// Listener that logs each failed attempt together with the current retry count.
	RetryListener loggingListener = new RetryListener() {
		@Override
		public <T extends Object, E extends Throwable> void onError(RetryContext context,
				RetryCallback<T, E> callback, Throwable throwable) {
			logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
		}
	};

	// maxAttempts(1) allows only the initial attempt, so failures surface without
	// being retried.
	RetryTemplate singleAttemptTemplate = RetryTemplate.builder()
		.maxAttempts(1)
		.retryOn(TransientAiException.class)
		.fixedBackoff(Duration.ofSeconds(1))
		.withListener(loggingListener)
		.build();

	OllamaOptions chatOptions = OllamaOptions.builder().model(MODEL).temperature(0.9).build();

	return OllamaChatModel.builder()
		.ollamaApi(ollamaApi)
		.defaultOptions(chatOptions)
		.retryTemplate(singleAttemptTemplate)
		.build();
}

View File

@@ -34,6 +34,7 @@ import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
@@ -47,6 +48,7 @@ import static org.springframework.ai.chat.observation.ChatModelObservationDocume
* Integration tests for observation instrumentation in {@link OllamaChatModel}.
*
* @author Thomas Vitale
* @author Alexandros Pappas
*/
@SpringBootTest(classes = OllamaChatModelObservationIT.Config.class)
public class OllamaChatModelObservationIT extends BaseOllamaIT {
@@ -169,7 +171,11 @@ public class OllamaChatModelObservationIT extends BaseOllamaIT {
@Bean
public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegistry observationRegistry) {
return OllamaChatModel.builder().ollamaApi(ollamaApi).observationRegistry(observationRegistry).build();
return OllamaChatModel.builder()
.ollamaApi(ollamaApi)
.observationRegistry(observationRegistry)
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();
}
}

View File

@@ -35,6 +35,7 @@ import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.retry.RetryUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.*;
@@ -82,6 +83,7 @@ class OllamaChatModelTests {
() -> OllamaChatModel.builder()
.ollamaApi(this.ollamaApi)
.defaultOptions(OllamaOptions.builder().model(OllamaModel.LLAMA2).build())
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.modelManagementOptions(null)
.build());
assertEquals("modelManagementOptions must not be null", exception.getMessage());

View File

@@ -28,18 +28,21 @@ import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.retry.RetryUtils;
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author Christian Tzolov
* @author Thomas Vitale
* @author Alexandros Pappas
*/
class OllamaChatRequestTests {
OllamaChatModel chatModel = OllamaChatModel.builder()
.ollamaApi(OllamaApi.builder().build())
.defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build())
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();
@Test
@@ -146,6 +149,7 @@ class OllamaChatRequestTests {
OllamaChatModel chatModel = OllamaChatModel.builder()
.ollamaApi(OllamaApi.builder().build())
.defaultOptions(OllamaOptions.builder().model("DEFAULT_OPTIONS_MODEL").build())
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();
var prompt1 = chatModel.buildRequestPrompt(new Prompt("Test message content"));

View File

@@ -23,7 +23,7 @@ import org.testcontainers.utility.DockerImageName;
*/
public final class OllamaImage {
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.5.2");
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.6.7");
private OllamaImage() {

View File

@@ -0,0 +1,97 @@
package org.springframework.ai.ollama;
import java.time.Instant;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.retry.TransientAiException;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.when;
/**
 * Unit tests verifying that {@link OllamaChatModel} routes its chat calls through the
 * configured {@link RetryTemplate}, so that transient failures of the {@link OllamaApi}
 * are retried until a response is obtained.
 *
 * @author Alexandros Pappas
 */
@ExtendWith(MockitoExtension.class)
class OllamaRetryTests {

	private static final String MODEL = OllamaModel.LLAMA3_2.getName();

	private TestRetryListener retryListener;

	private RetryTemplate retryTemplate;

	@Mock
	private OllamaApi ollamaApi;

	private OllamaChatModel chatModel;

	@BeforeEach
	public void beforeEach() {
		// NOTE(review): SHORT_RETRY_TEMPLATE is referenced as a static member of
		// RetryUtils, so it looks like a shared instance; registering a new listener
		// before every test would then accumulate listeners on it. Harmless with a
		// single test, but confirm before adding more tests to this class.
		this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE;
		this.retryListener = new TestRetryListener();
		this.retryTemplate.registerListener(this.retryListener);

		this.chatModel = OllamaChatModel.builder()
			.ollamaApi(this.ollamaApi)
			.defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build())
			.retryTemplate(this.retryTemplate)
			.build();
	}

	/**
	 * The mocked API fails twice with a {@link TransientAiException} and succeeds on the
	 * third call; the model is expected to return that third response.
	 */
	@Test
	void ollamaChatTransientError() {
		String promptText = "What is the capital of Bulgaria and what is the size? What it the national anthem?";
		var expectedChatResponse = new OllamaApi.ChatResponse("CHAT_COMPLETION_ID", Instant.now(),
				OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT).content("Response").build(), null, true,
				null, null, null, null, null, null);

		when(this.ollamaApi.chat(isA(OllamaApi.ChatRequest.class)))
			.thenThrow(new TransientAiException("Transient Error 1"))
			.thenThrow(new TransientAiException("Transient Error 2"))
			.thenReturn(expectedChatResponse);

		var result = this.chatModel.call(new Prompt(promptText));

		assertThat(result).isNotNull();
		// Compare by value: reference identity on a String only holds by accident of
		// literal interning and breaks as soon as the text is copied or transformed.
		assertThat(result.getResult().getOutput().getText()).isEqualTo("Response");
		assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2);
		assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2);
	}

	/**
	 * Records the retry count reported by the {@link RetryContext} on the most recent
	 * success and error callbacks.
	 */
	private static class TestRetryListener implements RetryListener {

		int onErrorRetryCount = 0;

		int onSuccessRetryCount = 0;

		@Override
		public <T, E extends Throwable> void onSuccess(RetryContext context, RetryCallback<T, E> callback, T result) {
			this.onSuccessRetryCount = context.getRetryCount();
		}

		@Override
		public <T, E extends Throwable> void onError(RetryContext context, RetryCallback<T, E> callback,
				Throwable throwable) {
			this.onErrorRetryCount = context.getRetryCount();
		}

	}

}

View File

@@ -0,0 +1,153 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.autoconfigure.ollama;
import io.micrometer.observation.ObservationRegistry;
import org.springframework.ai.autoconfigure.chat.model.ToolCallingAutoConfiguration;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.OllamaEmbeddingModel;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.autoconfigure.web.reactive.function.client.WebClientAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
/**
 * {@link AutoConfiguration Auto-configuration} for the Ollama API client and the Ollama
 * chat and embedding models.
 * <p>
 * Every bean is declared with {@link ConditionalOnMissingBean}, so an application can
 * replace any of them with its own definition. The chat and embedding models can be
 * switched off individually through their {@code enabled} property.
 *
 * @author Christian Tzolov
 * @author Eddú Meléndez
 * @author Thomas Vitale
 * @since 0.8.0
 */
@AutoConfiguration(after = { RestClientAutoConfiguration.class, ToolCallingAutoConfiguration.class })
@ConditionalOnClass(OllamaApi.class)
@EnableConfigurationProperties({ OllamaChatProperties.class, OllamaEmbeddingProperties.class,
		OllamaConnectionProperties.class, OllamaInitializationProperties.class })
@ImportAutoConfiguration(classes = { RestClientAutoConfiguration.class, ToolCallingAutoConfiguration.class,
		WebClientAutoConfiguration.class })
public class OllamaAutoConfiguration {

	/**
	 * Exposes the connection details derived from {@link OllamaConnectionProperties}
	 * unless another {@link OllamaConnectionDetails} bean is already present.
	 * @param properties the bound connection properties
	 * @return connection details backed by the given properties
	 */
	@Bean
	@ConditionalOnMissingBean(OllamaConnectionDetails.class)
	public PropertiesOllamaConnectionDetails ollamaConnectionDetails(OllamaConnectionProperties properties) {
		return new PropertiesOllamaConnectionDetails(properties);
	}

	/**
	 * Creates the {@link OllamaApi} client pointing at the configured base URL.
	 * <p>
	 * The client is assembled through {@link OllamaApi#builder()} rather than a
	 * constructor, because the {@code OllamaApi} constructor is not publicly accessible.
	 * @param connectionDetails supplies the base URL of the Ollama server
	 * @param restClientBuilderProvider optional {@link RestClient.Builder}; a fresh
	 * builder is used when none is available
	 * @param webClientBuilderProvider optional {@link WebClient.Builder}; a fresh builder
	 * is used when none is available
	 * @return the configured API client
	 */
	@Bean
	@ConditionalOnMissingBean
	public OllamaApi ollamaApi(OllamaConnectionDetails connectionDetails,
			ObjectProvider<RestClient.Builder> restClientBuilderProvider,
			ObjectProvider<WebClient.Builder> webClientBuilderProvider) {
		return OllamaApi.builder()
			.baseUrl(connectionDetails.getBaseUrl())
			.restClientBuilder(restClientBuilderProvider.getIfAvailable(RestClient::builder))
			.webClientBuilder(webClientBuilderProvider.getIfAvailable(WebClient::builder))
			.build();
	}

	/**
	 * Creates the {@link OllamaChatModel}, wiring in tool calling, observation, model
	 * management and retry support.
	 * <p>
	 * NOTE(review): {@link RetryTemplate} is injected as a required bean, yet this class
	 * is neither ordered after nor imports a retry auto-configuration. Confirm that a
	 * {@code RetryTemplate} bean is always present when this configuration is used.
	 * @param ollamaApi the API client used for chat requests
	 * @param properties chat properties providing the default options
	 * @param initProperties initialization properties controlling model pulling
	 * @param toolCallingManager manager handling tool calls
	 * @param observationRegistry optional registry; {@link ObservationRegistry#NOOP} is
	 * used unless exactly one is available
	 * @param observationConvention optional custom observation convention
	 * @param retryTemplate template used to retry chat calls
	 * @return the configured chat model
	 */
	@Bean
	@ConditionalOnMissingBean
	@ConditionalOnProperty(prefix = OllamaChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
			matchIfMissing = true)
	public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties,
			OllamaInitializationProperties initProperties, ToolCallingManager toolCallingManager,
			ObjectProvider<ObservationRegistry> observationRegistry,
			ObjectProvider<ChatModelObservationConvention> observationConvention, RetryTemplate retryTemplate) {
		// Models are only pulled at startup when chat is included in the initialization.
		var chatModelPullStrategy = initProperties.getChat().isInclude() ? initProperties.getPullModelStrategy()
				: PullModelStrategy.NEVER;

		var chatModel = OllamaChatModel.builder()
			.ollamaApi(ollamaApi)
			.defaultOptions(properties.getOptions())
			.toolCallingManager(toolCallingManager)
			.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
			.modelManagementOptions(
					new ModelManagementOptions(chatModelPullStrategy, initProperties.getChat().getAdditionalModels(),
							initProperties.getTimeout(), initProperties.getMaxRetries()))
			.retryTemplate(retryTemplate)
			.build();

		observationConvention.ifAvailable(chatModel::setObservationConvention);

		return chatModel;
	}

	/**
	 * Creates the {@link OllamaEmbeddingModel} with observation and model management
	 * support.
	 * @param ollamaApi the API client used for embedding requests
	 * @param properties embedding properties providing the default options
	 * @param initProperties initialization properties controlling model pulling
	 * @param observationRegistry optional registry; {@link ObservationRegistry#NOOP} is
	 * used unless exactly one is available
	 * @param observationConvention optional custom observation convention
	 * @return the configured embedding model
	 */
	@Bean
	@ConditionalOnMissingBean
	@ConditionalOnProperty(prefix = OllamaEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
			matchIfMissing = true)
	public OllamaEmbeddingModel ollamaEmbeddingModel(OllamaApi ollamaApi, OllamaEmbeddingProperties properties,
			OllamaInitializationProperties initProperties, ObjectProvider<ObservationRegistry> observationRegistry,
			ObjectProvider<EmbeddingModelObservationConvention> observationConvention) {
		// Models are only pulled at startup when embedding is included in the
		// initialization.
		var embeddingModelPullStrategy = initProperties.getEmbedding().isInclude()
				? initProperties.getPullModelStrategy() : PullModelStrategy.NEVER;

		var embeddingModel = OllamaEmbeddingModel.builder()
			.ollamaApi(ollamaApi)
			.defaultOptions(properties.getOptions())
			.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
			.modelManagementOptions(new ModelManagementOptions(embeddingModelPullStrategy,
					initProperties.getEmbedding().getAdditionalModels(), initProperties.getTimeout(),
					initProperties.getMaxRetries()))
			.build();

		observationConvention.ifAvailable(embeddingModel::setObservationConvention);

		return embeddingModel;
	}

	/**
	 * Registers a {@link DefaultFunctionCallbackResolver} bound to the application
	 * context unless another {@link FunctionCallbackResolver} bean is already present.
	 * @param context the application context handed to the resolver
	 * @return the function callback resolver
	 */
	@Bean
	@ConditionalOnMissingBean
	public FunctionCallbackResolver springAiFunctionManager(ApplicationContext context) {
		DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver();
		manager.setApplicationContext(context);
		return manager;
	}

	/**
	 * {@link OllamaConnectionDetails} implementation that reads the base URL from
	 * {@link OllamaConnectionProperties}.
	 */
	static class PropertiesOllamaConnectionDetails implements OllamaConnectionDetails {

		private final OllamaConnectionProperties properties;

		PropertiesOllamaConnectionDetails(OllamaConnectionProperties properties) {
			this.properties = properties;
		}

		@Override
		public String getBaseUrl() {
			return this.properties.getBaseUrl();
		}

	}

}