feat(ollama): add retry template integration to OllamaChatModel
* Update tests that are supposed to fail to not use retry * Upgrade to use Ollama 0.6.7 Signed-off-by: Alexandros Pappas <alexandros.pappas@yiluhub.com>
This commit is contained in:
committed by
Mark Pollack
parent
6ef15bdd74
commit
cfbefee6a2
@@ -34,6 +34,14 @@
|
||||
</dependency>
|
||||
|
||||
<!-- Spring AI auto configurations -->
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-autoconfigure-retry</artifactId>
|
||||
<version>${project.parent.version}</version>
|
||||
<optional>true</optional>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-autoconfigure-model-tool</artifactId>
|
||||
|
||||
@@ -18,6 +18,7 @@ package org.springframework.ai.model.ollama.autoconfigure;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
@@ -41,8 +42,9 @@ public class OllamaChatAutoConfigurationTests {
|
||||
"spring.ai.ollama.chat.options.topP=0.56",
|
||||
"spring.ai.ollama.chat.options.topK=123")
|
||||
// @formatter:on
|
||||
.withConfiguration(
|
||||
AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaChatAutoConfiguration.class))
|
||||
|
||||
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
|
||||
RestClientAutoConfiguration.class, OllamaChatAutoConfiguration.class))
|
||||
.run(context -> {
|
||||
var chatProperties = context.getBean(OllamaChatProperties.class);
|
||||
var connectionProperties = context.getBean(OllamaConnectionProperties.class);
|
||||
|
||||
@@ -18,6 +18,7 @@ package org.springframework.ai.model.ollama.autoconfigure;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
@@ -26,6 +27,7 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
* @author Alexandros Pappas
|
||||
* @since 0.8.0
|
||||
*/
|
||||
public class OllamaEmbeddingAutoConfigurationTests {
|
||||
@@ -41,8 +43,9 @@ public class OllamaEmbeddingAutoConfigurationTests {
|
||||
"spring.ai.ollama.embedding.options.topK=13"
|
||||
// @formatter:on
|
||||
)
|
||||
.withConfiguration(
|
||||
AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaEmbeddingAutoConfiguration.class))
|
||||
|
||||
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
|
||||
RestClientAutoConfiguration.class, OllamaChatAutoConfiguration.class))
|
||||
.run(context -> {
|
||||
var embeddingProperties = context.getBean(OllamaEmbeddingProperties.class);
|
||||
var connectionProperties = context.getBean(OllamaConnectionProperties.class);
|
||||
|
||||
@@ -65,8 +65,13 @@ import org.springframework.ai.ollama.api.common.OllamaApiConstants;
|
||||
import org.springframework.ai.ollama.management.ModelManagementOptions;
|
||||
import org.springframework.ai.ollama.management.OllamaModelManager;
|
||||
import org.springframework.ai.ollama.management.PullModelStrategy;
|
||||
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.ai.util.json.JsonParser;
|
||||
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
@@ -129,27 +134,32 @@ public class OllamaChatModel implements ChatModel {
|
||||
|
||||
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
|
||||
|
||||
private final RetryTemplate retryTemplate;
|
||||
|
||||
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
|
||||
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
|
||||
this(ollamaApi, defaultOptions, toolCallingManager, observationRegistry, modelManagementOptions,
|
||||
new DefaultToolExecutionEligibilityPredicate());
|
||||
new DefaultToolExecutionEligibilityPredicate(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
|
||||
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions,
|
||||
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
|
||||
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate, RetryTemplate retryTemplate) {
|
||||
|
||||
Assert.notNull(ollamaApi, "ollamaApi must not be null");
|
||||
Assert.notNull(defaultOptions, "defaultOptions must not be null");
|
||||
Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
|
||||
Assert.notNull(observationRegistry, "observationRegistry must not be null");
|
||||
Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
|
||||
Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate must not be null");
|
||||
Assert.notNull(retryTemplate, "retryTemplate must not be null");
|
||||
this.chatApi = ollamaApi;
|
||||
this.defaultOptions = defaultOptions;
|
||||
this.toolCallingManager = toolCallingManager;
|
||||
this.observationRegistry = observationRegistry;
|
||||
this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
|
||||
this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
|
||||
this.retryTemplate = retryTemplate;
|
||||
initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
|
||||
}
|
||||
|
||||
@@ -237,7 +247,7 @@ public class OllamaChatModel implements ChatModel {
|
||||
this.observationRegistry)
|
||||
.observe(() -> {
|
||||
|
||||
OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request);
|
||||
OllamaApi.ChatResponse ollamaResponse = this.retryTemplate.execute(ctx -> this.chatApi.chat(request));
|
||||
|
||||
List<AssistantMessage.ToolCall> toolCalls = ollamaResponse.message().toolCalls() == null ? List.of()
|
||||
: ollamaResponse.message()
|
||||
@@ -540,6 +550,8 @@ public class OllamaChatModel implements ChatModel {
|
||||
|
||||
private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();
|
||||
|
||||
private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
|
||||
|
||||
private Builder() {
|
||||
}
|
||||
|
||||
@@ -574,13 +586,20 @@ public class OllamaChatModel implements ChatModel {
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder retryTemplate(RetryTemplate retryTemplate) {
|
||||
this.retryTemplate = retryTemplate;
|
||||
return this;
|
||||
}
|
||||
|
||||
public OllamaChatModel build() {
|
||||
if (this.toolCallingManager != null) {
|
||||
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.toolCallingManager,
|
||||
this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate);
|
||||
this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate,
|
||||
this.retryTemplate);
|
||||
}
|
||||
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER,
|
||||
this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate);
|
||||
this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate,
|
||||
this.retryTemplate);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -51,6 +51,7 @@ import org.springframework.web.reactive.function.client.WebClient;
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
* @author Jonghoon Park
|
||||
* @author Alexandros Pappas
|
||||
* @since 0.8.0
|
||||
*/
|
||||
// @formatter:off
|
||||
@@ -64,6 +65,9 @@ public final class OllamaApi {
|
||||
|
||||
private static final Log logger = LogFactory.getLog(OllamaApi.class);
|
||||
|
||||
|
||||
private static final String DEFAULT_BASE_URL = "http://localhost:11434";
|
||||
|
||||
private final RestClient restClient;
|
||||
|
||||
private final WebClient webClient;
|
||||
@@ -77,11 +81,13 @@ public final class OllamaApi {
|
||||
*/
|
||||
private OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
|
||||
|
||||
|
||||
Consumer<HttpHeaders> defaultHeaders = headers -> {
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
headers.setAccept(List.of(MediaType.APPLICATION_JSON));
|
||||
};
|
||||
|
||||
|
||||
this.restClient = restClientBuilder
|
||||
.clone()
|
||||
.baseUrl(baseUrl)
|
||||
@@ -89,6 +95,7 @@ public final class OllamaApi {
|
||||
.defaultStatusHandler(responseErrorHandler)
|
||||
.build();
|
||||
|
||||
|
||||
this.webClient = webClientBuilder
|
||||
.clone()
|
||||
.baseUrl(baseUrl)
|
||||
|
||||
@@ -36,6 +36,7 @@ import org.springframework.ai.ollama.api.OllamaApi;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.ollama.api.tool.MockWeatherService;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
@@ -120,6 +121,7 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT {
|
||||
return OllamaChatModel.builder()
|
||||
.ollamaApi(ollamaApi)
|
||||
.defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build())
|
||||
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
@@ -52,8 +52,12 @@ import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.ollama.management.ModelManagementOptions;
|
||||
import org.springframework.ai.ollama.management.OllamaModelManager;
|
||||
import org.springframework.ai.ollama.management.PullModelStrategy;
|
||||
|
||||
import org.springframework.ai.support.ToolCallbacks;
|
||||
import org.springframework.ai.tool.annotation.Tool;
|
||||
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
@@ -371,6 +375,7 @@ class OllamaChatModelIT extends BaseOllamaIT {
|
||||
.pullModelStrategy(PullModelStrategy.WHEN_MISSING)
|
||||
.additionalModels(List.of(ADDITIONAL_MODEL))
|
||||
.build())
|
||||
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
package org.springframework.ai.ollama;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -27,11 +28,17 @@ import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.content.Media;
|
||||
import org.springframework.ai.ollama.api.OllamaApi;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.ai.retry.TransientAiException;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.retry.RetryCallback;
|
||||
import org.springframework.retry.RetryContext;
|
||||
import org.springframework.retry.RetryListener;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.util.MimeTypeUtils;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
@@ -86,9 +93,23 @@ class OllamaChatModelMultimodalIT extends BaseOllamaIT {
|
||||
|
||||
@Bean
|
||||
public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
|
||||
RetryTemplate retryTemplate = RetryTemplate.builder()
|
||||
.maxAttempts(1)
|
||||
.retryOn(TransientAiException.class)
|
||||
.fixedBackoff(Duration.ofSeconds(1))
|
||||
.withListener(new RetryListener() {
|
||||
|
||||
@Override
|
||||
public <T extends Object, E extends Throwable> void onError(RetryContext context,
|
||||
RetryCallback<T, E> callback, Throwable throwable) {
|
||||
logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
|
||||
}
|
||||
})
|
||||
.build();
|
||||
return OllamaChatModel.builder()
|
||||
.ollamaApi(ollamaApi)
|
||||
.defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build())
|
||||
.retryTemplate(retryTemplate)
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ import org.springframework.ai.observation.conventions.AiProvider;
|
||||
import org.springframework.ai.ollama.api.OllamaApi;
|
||||
import org.springframework.ai.ollama.api.OllamaModel;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
@@ -47,6 +48,7 @@ import static org.springframework.ai.chat.observation.ChatModelObservationDocume
|
||||
* Integration tests for observation instrumentation in {@link OllamaChatModel}.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @author Alexandros Pappas
|
||||
*/
|
||||
@SpringBootTest(classes = OllamaChatModelObservationIT.Config.class)
|
||||
public class OllamaChatModelObservationIT extends BaseOllamaIT {
|
||||
@@ -169,7 +171,11 @@ public class OllamaChatModelObservationIT extends BaseOllamaIT {
|
||||
|
||||
@Bean
|
||||
public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegistry observationRegistry) {
|
||||
return OllamaChatModel.builder().ollamaApi(ollamaApi).observationRegistry(observationRegistry).build();
|
||||
return OllamaChatModel.builder()
|
||||
.ollamaApi(ollamaApi)
|
||||
.observationRegistry(observationRegistry)
|
||||
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ import org.springframework.ai.ollama.api.OllamaApi;
|
||||
import org.springframework.ai.ollama.api.OllamaModel;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.ollama.management.ModelManagementOptions;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
@@ -82,6 +83,7 @@ class OllamaChatModelTests {
|
||||
() -> OllamaChatModel.builder()
|
||||
.ollamaApi(this.ollamaApi)
|
||||
.defaultOptions(OllamaOptions.builder().model(OllamaModel.LLAMA2).build())
|
||||
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
|
||||
.modelManagementOptions(null)
|
||||
.build());
|
||||
assertEquals("modelManagementOptions must not be null", exception.getMessage());
|
||||
|
||||
@@ -28,18 +28,21 @@ import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.definition.DefaultToolDefinition;
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
* @author Alexandros Pappas
|
||||
*/
|
||||
class OllamaChatRequestTests {
|
||||
|
||||
OllamaChatModel chatModel = OllamaChatModel.builder()
|
||||
.ollamaApi(OllamaApi.builder().build())
|
||||
.defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build())
|
||||
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
|
||||
.build();
|
||||
|
||||
@Test
|
||||
@@ -146,6 +149,7 @@ class OllamaChatRequestTests {
|
||||
OllamaChatModel chatModel = OllamaChatModel.builder()
|
||||
.ollamaApi(OllamaApi.builder().build())
|
||||
.defaultOptions(OllamaOptions.builder().model("DEFAULT_OPTIONS_MODEL").build())
|
||||
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
|
||||
.build();
|
||||
|
||||
var prompt1 = chatModel.buildRequestPrompt(new Prompt("Test message content"));
|
||||
|
||||
@@ -23,7 +23,7 @@ import org.testcontainers.utility.DockerImageName;
|
||||
*/
|
||||
public final class OllamaImage {
|
||||
|
||||
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.5.2");
|
||||
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.6.7");
|
||||
|
||||
private OllamaImage() {
|
||||
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
package org.springframework.ai.ollama;
|
||||
|
||||
import java.time.Instant;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.ollama.api.OllamaApi;
|
||||
import org.springframework.ai.ollama.api.OllamaModel;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.ai.retry.TransientAiException;
|
||||
import org.springframework.retry.RetryCallback;
|
||||
import org.springframework.retry.RetryContext;
|
||||
import org.springframework.retry.RetryListener;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.isA;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
/**
|
||||
* Tests for the OllamaRetryTests class.
|
||||
*
|
||||
* @author Alexandros Pappas
|
||||
*/
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class OllamaRetryTests {
|
||||
|
||||
private static final String MODEL = OllamaModel.LLAMA3_2.getName();
|
||||
|
||||
private TestRetryListener retryListener;
|
||||
|
||||
private RetryTemplate retryTemplate;
|
||||
|
||||
@Mock
|
||||
private OllamaApi ollamaApi;
|
||||
|
||||
private OllamaChatModel chatModel;
|
||||
|
||||
@BeforeEach
|
||||
public void beforeEach() {
|
||||
this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE;
|
||||
this.retryListener = new TestRetryListener();
|
||||
this.retryTemplate.registerListener(this.retryListener);
|
||||
|
||||
this.chatModel = OllamaChatModel.builder()
|
||||
.ollamaApi(this.ollamaApi)
|
||||
.defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build())
|
||||
.retryTemplate(this.retryTemplate)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Test
|
||||
void ollamaChatTransientError() {
|
||||
String promptText = "What is the capital of Bulgaria and what is the size? What it the national anthem?";
|
||||
var expectedChatResponse = new OllamaApi.ChatResponse("CHAT_COMPLETION_ID", Instant.now(),
|
||||
OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT).content("Response").build(), null, true,
|
||||
null, null, null, null, null, null);
|
||||
|
||||
when(this.ollamaApi.chat(isA(OllamaApi.ChatRequest.class)))
|
||||
.thenThrow(new TransientAiException("Transient Error 1"))
|
||||
.thenThrow(new TransientAiException("Transient Error 2"))
|
||||
.thenReturn(expectedChatResponse);
|
||||
|
||||
var result = this.chatModel.call(new Prompt(promptText));
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.getResult().getOutput().getText()).isSameAs("Response");
|
||||
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2);
|
||||
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2);
|
||||
}
|
||||
|
||||
private static class TestRetryListener implements RetryListener {
|
||||
|
||||
int onErrorRetryCount = 0;
|
||||
|
||||
int onSuccessRetryCount = 0;
|
||||
|
||||
@Override
|
||||
public <T, E extends Throwable> void onSuccess(RetryContext context, RetryCallback<T, E> callback, T result) {
|
||||
this.onSuccessRetryCount = context.getRetryCount();
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T, E extends Throwable> void onError(RetryContext context, RetryCallback<T, E> callback,
|
||||
Throwable throwable) {
|
||||
this.onErrorRetryCount = context.getRetryCount();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.autoconfigure.ollama;
|
||||
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
|
||||
import org.springframework.ai.autoconfigure.chat.model.ToolCallingAutoConfiguration;
|
||||
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
|
||||
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
|
||||
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
|
||||
import org.springframework.ai.model.function.FunctionCallbackResolver;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.ollama.OllamaChatModel;
|
||||
import org.springframework.ai.ollama.OllamaEmbeddingModel;
|
||||
import org.springframework.ai.ollama.api.OllamaApi;
|
||||
import org.springframework.ai.ollama.management.ModelManagementOptions;
|
||||
import org.springframework.ai.ollama.management.PullModelStrategy;
|
||||
import org.springframework.beans.factory.ObjectProvider;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.web.reactive.function.client.WebClientAutoConfiguration;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.web.client.RestClient;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
|
||||
/**
|
||||
* {@link AutoConfiguration Auto-configuration} for Ollama Chat Client.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Eddú Meléndez
|
||||
* @author Thomas Vitale
|
||||
* @since 0.8.0
|
||||
*/
|
||||
@AutoConfiguration(after = { RestClientAutoConfiguration.class, ToolCallingAutoConfiguration.class })
|
||||
@ConditionalOnClass(OllamaApi.class)
|
||||
@EnableConfigurationProperties({ OllamaChatProperties.class, OllamaEmbeddingProperties.class,
|
||||
OllamaConnectionProperties.class, OllamaInitializationProperties.class })
|
||||
@ImportAutoConfiguration(classes = { RestClientAutoConfiguration.class, ToolCallingAutoConfiguration.class,
|
||||
WebClientAutoConfiguration.class })
|
||||
public class OllamaAutoConfiguration {
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean(OllamaConnectionDetails.class)
|
||||
public PropertiesOllamaConnectionDetails ollamaConnectionDetails(OllamaConnectionProperties properties) {
|
||||
return new PropertiesOllamaConnectionDetails(properties);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean
|
||||
public OllamaApi ollamaApi(OllamaConnectionDetails connectionDetails,
|
||||
ObjectProvider<RestClient.Builder> restClientBuilderProvider,
|
||||
ObjectProvider<WebClient.Builder> webClientBuilderProvider) {
|
||||
return new OllamaApi(connectionDetails.getBaseUrl(),
|
||||
restClientBuilderProvider.getIfAvailable(RestClient::builder),
|
||||
webClientBuilderProvider.getIfAvailable(WebClient::builder));
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean
|
||||
@ConditionalOnProperty(prefix = OllamaChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
|
||||
matchIfMissing = true)
|
||||
public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties,
|
||||
OllamaInitializationProperties initProperties, ToolCallingManager toolCallingManager,
|
||||
ObjectProvider<ObservationRegistry> observationRegistry,
|
||||
ObjectProvider<ChatModelObservationConvention> observationConvention, RetryTemplate retryTemplate) {
|
||||
var chatModelPullStrategy = initProperties.getChat().isInclude() ? initProperties.getPullModelStrategy()
|
||||
: PullModelStrategy.NEVER;
|
||||
|
||||
var chatModel = OllamaChatModel.builder()
|
||||
.ollamaApi(ollamaApi)
|
||||
.defaultOptions(properties.getOptions())
|
||||
.toolCallingManager(toolCallingManager)
|
||||
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
|
||||
.modelManagementOptions(
|
||||
new ModelManagementOptions(chatModelPullStrategy, initProperties.getChat().getAdditionalModels(),
|
||||
initProperties.getTimeout(), initProperties.getMaxRetries()))
|
||||
.retryTemplate(retryTemplate)
|
||||
.build();
|
||||
|
||||
observationConvention.ifAvailable(chatModel::setObservationConvention);
|
||||
|
||||
return chatModel;
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean
|
||||
@ConditionalOnProperty(prefix = OllamaEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
|
||||
matchIfMissing = true)
|
||||
public OllamaEmbeddingModel ollamaEmbeddingModel(OllamaApi ollamaApi, OllamaEmbeddingProperties properties,
|
||||
OllamaInitializationProperties initProperties, ObjectProvider<ObservationRegistry> observationRegistry,
|
||||
ObjectProvider<EmbeddingModelObservationConvention> observationConvention) {
|
||||
var embeddingModelPullStrategy = initProperties.getEmbedding().isInclude()
|
||||
? initProperties.getPullModelStrategy() : PullModelStrategy.NEVER;
|
||||
|
||||
var embeddingModel = OllamaEmbeddingModel.builder()
|
||||
.ollamaApi(ollamaApi)
|
||||
.defaultOptions(properties.getOptions())
|
||||
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
|
||||
.modelManagementOptions(new ModelManagementOptions(embeddingModelPullStrategy,
|
||||
initProperties.getEmbedding().getAdditionalModels(), initProperties.getTimeout(),
|
||||
initProperties.getMaxRetries()))
|
||||
.build();
|
||||
|
||||
observationConvention.ifAvailable(embeddingModel::setObservationConvention);
|
||||
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean
|
||||
public FunctionCallbackResolver springAiFunctionManager(ApplicationContext context) {
|
||||
DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver();
|
||||
manager.setApplicationContext(context);
|
||||
return manager;
|
||||
}
|
||||
|
||||
static class PropertiesOllamaConnectionDetails implements OllamaConnectionDetails {
|
||||
|
||||
private final OllamaConnectionProperties properties;
|
||||
|
||||
PropertiesOllamaConnectionDetails(OllamaConnectionProperties properties) {
|
||||
this.properties = properties;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getBaseUrl() {
|
||||
return this.properties.getBaseUrl();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user