diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/pom.xml b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/pom.xml index 342a6b118..e3a171e2d 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/pom.xml +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/pom.xml @@ -34,6 +34,14 @@ + + + org.springframework.ai + spring-ai-autoconfigure-retry + ${project.parent.version} + true + + org.springframework.ai spring-ai-autoconfigure-model-tool diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatAutoConfigurationTests.java b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatAutoConfigurationTests.java index de27c35c2..6e3fe39d8 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatAutoConfigurationTests.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatAutoConfigurationTests.java @@ -18,6 +18,7 @@ package org.springframework.ai.model.ollama.autoconfigure; import org.junit.jupiter.api.Test; +import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -41,8 +42,9 @@ public class OllamaChatAutoConfigurationTests { "spring.ai.ollama.chat.options.topP=0.56", "spring.ai.ollama.chat.options.topK=123") // @formatter:on - .withConfiguration( - AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaChatAutoConfiguration.class)) + + 
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OllamaChatAutoConfiguration.class)) .run(context -> { var chatProperties = context.getBean(OllamaChatProperties.class); var connectionProperties = context.getBean(OllamaConnectionProperties.class); diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingAutoConfigurationTests.java b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingAutoConfigurationTests.java index 6f2443278..bc6f2a6d3 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingAutoConfigurationTests.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingAutoConfigurationTests.java @@ -18,6 +18,7 @@ package org.springframework.ai.model.ollama.autoconfigure; import org.junit.jupiter.api.Test; +import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -26,6 +27,7 @@ import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov + * @author Alexandros Pappas * @since 0.8.0 */ public class OllamaEmbeddingAutoConfigurationTests { @@ -41,8 +43,9 @@ public class OllamaEmbeddingAutoConfigurationTests { "spring.ai.ollama.embedding.options.topK=13" // @formatter:on ) - .withConfiguration( - AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaEmbeddingAutoConfiguration.class)) + + 
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OllamaEmbeddingAutoConfiguration.class)) .run(context -> { var embeddingProperties = context.getBean(OllamaEmbeddingProperties.class); var connectionProperties = context.getBean(OllamaConnectionProperties.class); diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 8d22df6dd..44dc45347 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -65,8 +65,13 @@ import org.springframework.ai.ollama.api.common.OllamaApiConstants; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; + import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.util.json.JsonParser; + +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; + import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -129,27 +134,32 @@ public class OllamaChatModel implements ChatModel { private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + private final RetryTemplate retryTemplate; + public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) { this(ollamaApi, defaultOptions, toolCallingManager, observationRegistry, modelManagementOptions, - new DefaultToolExecutionEligibilityPredicate()); + new
DefaultToolExecutionEligibilityPredicate(), RetryUtils.DEFAULT_RETRY_TEMPLATE); } public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions, - ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate, RetryTemplate retryTemplate) { + Assert.notNull(ollamaApi, "ollamaApi must not be null"); Assert.notNull(defaultOptions, "defaultOptions must not be null"); Assert.notNull(toolCallingManager, "toolCallingManager must not be null"); Assert.notNull(observationRegistry, "observationRegistry must not be null"); Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null"); Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); this.chatApi = ollamaApi; this.defaultOptions = defaultOptions; this.toolCallingManager = toolCallingManager; this.observationRegistry = observationRegistry; this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions); this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; + this.retryTemplate = retryTemplate; initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy()); } @@ -237,7 +247,7 @@ public class OllamaChatModel implements ChatModel { this.observationRegistry) .observe(() -> { - OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request); + OllamaApi.ChatResponse ollamaResponse = this.retryTemplate.execute(ctx -> this.chatApi.chat(request)); List toolCalls = ollamaResponse.message().toolCalls() == null ? 
List.of() : ollamaResponse.message() @@ -540,6 +550,8 @@ public class OllamaChatModel implements ChatModel { private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults(); + private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + private Builder() { } @@ -574,13 +586,20 @@ public class OllamaChatModel implements ChatModel { return this; } + public Builder retryTemplate(RetryTemplate retryTemplate) { + this.retryTemplate = retryTemplate; + return this; + } + public OllamaChatModel build() { if (this.toolCallingManager != null) { return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.toolCallingManager, - this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate); + this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate, + this.retryTemplate); } return new OllamaChatModel(this.ollamaApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, - this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate); + this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate, + this.retryTemplate); } } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java index e0ffc06c3..b481386a4 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java @@ -51,6 +51,7 @@ import org.springframework.web.reactive.function.client.WebClient; * @author Christian Tzolov * @author Thomas Vitale * @author Jonghoon Park + * @author Alexandros Pappas * @since 0.8.0 */ // @formatter:off @@ -64,6 +65,9 @@ public final class OllamaApi { private static final Log logger = LogFactory.getLog(OllamaApi.class); + + private static final String 
DEFAULT_BASE_URL = "http://localhost:11434"; + private final RestClient restClient; private final WebClient webClient; @@ -77,11 +81,13 @@ public final class OllamaApi { */ private OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { + Consumer defaultHeaders = headers -> { headers.setContentType(MediaType.APPLICATION_JSON); headers.setAccept(List.of(MediaType.APPLICATION_JSON)); }; + this.restClient = restClientBuilder .clone() .baseUrl(baseUrl) @@ -89,6 +95,7 @@ public final class OllamaApi { .defaultStatusHandler(responseErrorHandler) .build(); + this.webClient = webClientBuilder .clone() .baseUrl(baseUrl) diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java index ef149203f..f8ec3091c 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java @@ -36,6 +36,7 @@ import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.api.tool.MockWeatherService; import org.springframework.ai.tool.function.FunctionToolCallback; +import org.springframework.ai.retry.RetryUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -120,6 +121,7 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT { return OllamaChatModel.builder() .ollamaApi(ollamaApi) .defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .build(); } diff --git 
a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java index b322a82d7..ae79f6735 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java @@ -52,8 +52,12 @@ import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; + import org.springframework.ai.support.ToolCallbacks; import org.springframework.ai.tool.annotation.Tool; + +import org.springframework.ai.retry.RetryUtils; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -371,6 +375,7 @@ class OllamaChatModelIT extends BaseOllamaIT { .pullModelStrategy(PullModelStrategy.WHEN_MISSING) .additionalModels(List.of(ADDITIONAL_MODEL)) .build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .build(); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java index 28064bb77..3174c459a 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java @@ -16,6 +16,7 @@ package org.springframework.ai.ollama; +import java.time.Duration; import java.util.List; import org.junit.jupiter.api.Test; @@ -27,11 +28,17 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.content.Media; 
import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.core.io.ClassPathResource; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.MimeTypeUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -86,9 +93,23 @@ class OllamaChatModelMultimodalIT extends BaseOllamaIT { @Bean public OllamaChatModel ollamaChat(OllamaApi ollamaApi) { + RetryTemplate retryTemplate = RetryTemplate.builder() + .maxAttempts(1) + .retryOn(TransientAiException.class) + .fixedBackoff(Duration.ofSeconds(1)) + .withListener(new RetryListener() { + + @Override + public void onError(RetryContext context, + RetryCallback callback, Throwable throwable) { + logger.warn("Retry error. 
Retry count:" + context.getRetryCount(), throwable); + } + }) + .build(); return OllamaChatModel.builder() .ollamaApi(ollamaApi) .defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build()) + .retryTemplate(retryTemplate) .build(); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java index 916a364ba..0d8b6a0b7 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java @@ -34,6 +34,7 @@ import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.retry.RetryUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -47,6 +48,7 @@ import static org.springframework.ai.chat.observation.ChatModelObservationDocume * Integration tests for observation instrumentation in {@link OllamaChatModel}. 
* * @author Thomas Vitale + * @author Alexandros Pappas */ @SpringBootTest(classes = OllamaChatModelObservationIT.Config.class) public class OllamaChatModelObservationIT extends BaseOllamaIT { @@ -169,7 +171,11 @@ public class OllamaChatModelObservationIT extends BaseOllamaIT { @Bean public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegistry observationRegistry) { - return OllamaChatModel.builder().ollamaApi(ollamaApi).observationRegistry(observationRegistry).build(); + return OllamaChatModel.builder() + .ollamaApi(ollamaApi) + .observationRegistry(observationRegistry) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) + .build(); } } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java index fdb3c43cb..1cb17781b 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java @@ -35,6 +35,7 @@ import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; +import org.springframework.ai.retry.RetryUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.*; @@ -82,6 +83,7 @@ class OllamaChatModelTests { () -> OllamaChatModel.builder() .ollamaApi(this.ollamaApi) .defaultOptions(OllamaOptions.builder().model(OllamaModel.LLAMA2).build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .modelManagementOptions(null) .build()); assertEquals("modelManagementOptions must not be null", exception.getMessage()); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java 
b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index 59baa37be..dbc65e1fb 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -28,18 +28,21 @@ import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.retry.RetryUtils; import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov * @author Thomas Vitale + * @author Alexandros Pappas */ class OllamaChatRequestTests { OllamaChatModel chatModel = OllamaChatModel.builder() .ollamaApi(OllamaApi.builder().build()) .defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .build(); @Test @@ -146,6 +149,7 @@ class OllamaChatRequestTests { OllamaChatModel chatModel = OllamaChatModel.builder() .ollamaApi(OllamaApi.builder().build()) .defaultOptions(OllamaOptions.builder().model("DEFAULT_OPTIONS_MODEL").build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .build(); var prompt1 = chatModel.buildRequestPrompt(new Prompt("Test message content")); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java index 2220bf226..e82ecb9ab 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java @@ -23,7 +23,7 @@ import org.testcontainers.utility.DockerImageName; */ public final class OllamaImage { - public static final DockerImageName 
DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.5.2"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.6.7"); private OllamaImage() { diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaRetryTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaRetryTests.java new file mode 100644 index 000000000..bb1c1bb22 --- /dev/null +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaRetryTests.java @@ -0,0 +1,97 @@ +package org.springframework.ai.ollama; + +import java.time.Instant; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.when; + +/** + * Tests for the retry behavior of {@link OllamaChatModel} on transient errors.
+ * + * @author Alexandros Pappas + */ +@ExtendWith(MockitoExtension.class) +class OllamaRetryTests { + + private static final String MODEL = OllamaModel.LLAMA3_2.getName(); + + private TestRetryListener retryListener; + + private RetryTemplate retryTemplate; + + @Mock + private OllamaApi ollamaApi; + + private OllamaChatModel chatModel; + + @BeforeEach + public void beforeEach() { + this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.registerListener(this.retryListener); + + this.chatModel = OllamaChatModel.builder() + .ollamaApi(this.ollamaApi) + .defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build()) + .retryTemplate(this.retryTemplate) + .build(); + } + + @Test + void ollamaChatTransientError() { + String promptText = "What is the capital of Bulgaria and what is the size? What it the national anthem?"; + var expectedChatResponse = new OllamaApi.ChatResponse("CHAT_COMPLETION_ID", Instant.now(), + OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT).content("Response").build(), null, true, + null, null, null, null, null, null); + + when(this.ollamaApi.chat(isA(OllamaApi.ChatRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(expectedChatResponse); + + var result = this.chatModel.call(new Prompt(promptText)); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getText()).isEqualTo("Response"); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void
onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java new file mode 100644 index 000000000..5675e4c4a --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java @@ -0,0 +1,153 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.autoconfigure.ollama; + +import io.micrometer.observation.ObservationRegistry; + +import org.springframework.ai.autoconfigure.chat.model.ToolCallingAutoConfiguration; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; +import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; +import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.ollama.OllamaChatModel; +import org.springframework.ai.ollama.OllamaEmbeddingModel; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.management.ModelManagementOptions; +import org.springframework.ai.ollama.management.PullModelStrategy; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.ImportAutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.autoconfigure.web.reactive.function.client.WebClientAutoConfiguration; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; + +/** + * {@link AutoConfiguration Auto-configuration} for Ollama Chat Client. 
+ * + * @author Christian Tzolov + * @author EddĂș MelĂ©ndez + * @author Thomas Vitale + * @since 0.8.0 + */ +@AutoConfiguration(after = { RestClientAutoConfiguration.class, ToolCallingAutoConfiguration.class }) +@ConditionalOnClass(OllamaApi.class) +@EnableConfigurationProperties({ OllamaChatProperties.class, OllamaEmbeddingProperties.class, + OllamaConnectionProperties.class, OllamaInitializationProperties.class }) +@ImportAutoConfiguration(classes = { RestClientAutoConfiguration.class, ToolCallingAutoConfiguration.class, + WebClientAutoConfiguration.class }) +public class OllamaAutoConfiguration { + + @Bean + @ConditionalOnMissingBean(OllamaConnectionDetails.class) + public PropertiesOllamaConnectionDetails ollamaConnectionDetails(OllamaConnectionProperties properties) { + return new PropertiesOllamaConnectionDetails(properties); + } + + @Bean + @ConditionalOnMissingBean + public OllamaApi ollamaApi(OllamaConnectionDetails connectionDetails, + ObjectProvider restClientBuilderProvider, + ObjectProvider webClientBuilderProvider) { + return new OllamaApi(connectionDetails.getBaseUrl(), + restClientBuilderProvider.getIfAvailable(RestClient::builder), + webClientBuilderProvider.getIfAvailable(WebClient::builder)); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = OllamaChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", + matchIfMissing = true) + public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties, + OllamaInitializationProperties initProperties, ToolCallingManager toolCallingManager, + ObjectProvider observationRegistry, + ObjectProvider observationConvention, RetryTemplate retryTemplate) { + var chatModelPullStrategy = initProperties.getChat().isInclude() ? 
initProperties.getPullModelStrategy() + : PullModelStrategy.NEVER; + + var chatModel = OllamaChatModel.builder() + .ollamaApi(ollamaApi) + .defaultOptions(properties.getOptions()) + .toolCallingManager(toolCallingManager) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .modelManagementOptions( + new ModelManagementOptions(chatModelPullStrategy, initProperties.getChat().getAdditionalModels(), + initProperties.getTimeout(), initProperties.getMaxRetries())) + .retryTemplate(retryTemplate) + .build(); + + observationConvention.ifAvailable(chatModel::setObservationConvention); + + return chatModel; + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = OllamaEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", + matchIfMissing = true) + public OllamaEmbeddingModel ollamaEmbeddingModel(OllamaApi ollamaApi, OllamaEmbeddingProperties properties, + OllamaInitializationProperties initProperties, ObjectProvider observationRegistry, + ObjectProvider observationConvention) { + var embeddingModelPullStrategy = initProperties.getEmbedding().isInclude() + ? 
initProperties.getPullModelStrategy() : PullModelStrategy.NEVER; + + var embeddingModel = OllamaEmbeddingModel.builder() + .ollamaApi(ollamaApi) + .defaultOptions(properties.getOptions()) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .modelManagementOptions(new ModelManagementOptions(embeddingModelPullStrategy, + initProperties.getEmbedding().getAdditionalModels(), initProperties.getTimeout(), + initProperties.getMaxRetries())) + .build(); + + observationConvention.ifAvailable(embeddingModel::setObservationConvention); + + return embeddingModel; + } + + @Bean + @ConditionalOnMissingBean + public FunctionCallbackResolver springAiFunctionManager(ApplicationContext context) { + DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver(); + manager.setApplicationContext(context); + return manager; + } + + static class PropertiesOllamaConnectionDetails implements OllamaConnectionDetails { + + private final OllamaConnectionProperties properties; + + PropertiesOllamaConnectionDetails(OllamaConnectionProperties properties) { + this.properties = properties; + } + + @Override + public String getBaseUrl() { + return this.properties.getBaseUrl(); + } + + } + +}