diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 283ae5b6e..c70bd1b9a 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -23,6 +23,7 @@ import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import io.micrometer.common.util.StringUtils; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -56,6 +57,10 @@ public class OpenAiApi { public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.TEXT_EMBEDDING_ADA_002.getValue(); private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; + private final String completionsPath; + + private final String embeddingsPath; + private final RestClient restClient; private final WebClient webClient; @@ -99,7 +104,24 @@ public class OpenAiApi { * @param responseErrorHandler Response error handler. */ public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { + this(baseUrl, openAiToken, "/v1/chat/completions", "/v1/embeddings", + restClientBuilder, webClientBuilder, responseErrorHandler); + } + /** + * Create a new chat completion api. + * + * @param baseUrl api base URL. + * @param openAiToken OpenAI apiKey. + * @param restClientBuilder RestClient builder. + * @param responseErrorHandler Response error handler. + */ + public OpenAiApi(String baseUrl, String openAiToken, String completionsPath, String embeddingsPath, + RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { + Assert.hasText(completionsPath, "Completions Path must not be null"); + Assert.hasText(embeddingsPath, "Embeddings Path must not be null"); + this.completionsPath = completionsPath; + this.embeddingsPath = embeddingsPath; this.restClient = restClientBuilder .baseUrl(baseUrl) .defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken)) @@ -812,7 +834,7 @@ public class OpenAiApi { Assert.isTrue(!chatRequest.stream(), "Request must set the steam property to false."); return this.restClient.post() - .uri("/v1/chat/completions") + .uri(this.completionsPath) .body(chatRequest) .retrieve() .toEntity(ChatCompletion.class); @@ -834,7 +856,7 @@ public class OpenAiApi { AtomicBoolean isInsideTool = new AtomicBoolean(false); return this.webClient.post() - .uri("/v1/chat/completions") + .uri(this.completionsPath) .body(Mono.just(chatRequest), ChatCompletionRequest.class) .retrieve() .bodyToFlux(String.class) @@ -1022,7 +1044,7 @@ public class OpenAiApi { } return this.restClient.post() - .uri("/v1/embeddings") + .uri(this.embeddingsPath) .body(embeddingRequest) .retrieve() .toEntity(new ParameterizedTypeReference<>() { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index 6c3340404..b17fe0a78 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -17,6 +17,7 @@ package org.springframework.ai.autoconfigure.openai; import java.util.List; +import org.jetbrains.annotations.NotNull; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; @@ -70,8 +71,7 @@ public class OpenAiAutoConfiguration { FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) { - var openAiApi = openAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), - chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, webClientBuilder, + var openAiApi = openAiApi(chatProperties, commonProperties, restClientBuilder, webClientBuilder, responseErrorHandler, "chat"); if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { @@ -90,17 +90,39 @@ public class OpenAiAutoConfiguration { WebClient.Builder webClientBuilder, RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) { - var openAiApi = openAiApi(embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(), - embeddingProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, webClientBuilder, + var openAiApi = openAiApi(embeddingProperties, commonProperties, restClientBuilder, webClientBuilder, responseErrorHandler, "embedding"); return new OpenAiEmbeddingModel(openAiApi, embeddingProperties.getMetadataMode(), embeddingProperties.getOptions(), retryTemplate); } - private OpenAiApi openAiApi(String baseUrl, String commonBaseUrl, String apiKey, String commonApiKey, + private OpenAiApi openAiApi(OpenAiChatProperties chatProperties, OpenAiConnectionProperties commonProperties, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler, String modelType) { + ResolvedBaseUrlAndApiKey result = getResolvedBaseUrlAndApiKey(chatProperties.getBaseUrl(), + chatProperties.getApiKey(), commonProperties, modelType); + + return new OpenAiApi(result.resolvedBaseUrl(), result.resolvedApiKey(), chatProperties.getCompletionsPath(), + OpenAiEmbeddingProperties.DEFAULT_EMBEDDINGS_PATH, restClientBuilder, webClientBuilder, + responseErrorHandler); + } + + private OpenAiApi openAiApi(OpenAiEmbeddingProperties embeddingProperties, + OpenAiConnectionProperties commonProperties, RestClient.Builder restClientBuilder, + WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler, String modelType) { + ResolvedBaseUrlAndApiKey result = getResolvedBaseUrlAndApiKey(embeddingProperties.getBaseUrl(), + embeddingProperties.getApiKey(), commonProperties, modelType); + + return new OpenAiApi(result.resolvedBaseUrl(), result.resolvedApiKey(), + OpenAiChatProperties.DEFAULT_COMPLETIONS_PATH, embeddingProperties.getEmbeddingsPath(), + restClientBuilder, webClientBuilder, responseErrorHandler); + } + + private static @NotNull ResolvedBaseUrlAndApiKey getResolvedBaseUrlAndApiKey(String baseUrl, String apiKey, + OpenAiConnectionProperties commonProperties, String modelType) { + var commonBaseUrl = commonProperties.getBaseUrl(); + var commonApiKey = commonProperties.getApiKey(); String resolvedBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl; Assert.hasText(resolvedBaseUrl, @@ -111,9 +133,11 @@ public class OpenAiAutoConfiguration { Assert.hasText(resolvedApiKey, "OpenAI API key must be set. Use the connection property: spring.ai.openai.api-key or spring.ai.openai." + modelType + ".api-key property."); + ResolvedBaseUrlAndApiKey result = new ResolvedBaseUrlAndApiKey(resolvedBaseUrl, resolvedApiKey); + return result; + } - return new OpenAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder, webClientBuilder, - responseErrorHandler); + private record ResolvedBaseUrlAndApiKey(String resolvedBaseUrl, String resolvedApiKey) { } @Bean diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java index f602f23b1..f1a301cd9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java @@ -28,11 +28,15 @@ public class OpenAiChatProperties extends OpenAiParentProperties { private static final Double DEFAULT_TEMPERATURE = 0.7; + public static final String DEFAULT_COMPLETIONS_PATH = "/v1/chat/completions"; + /** * Enable OpenAI chat model. */ private boolean enabled = true; + private String completionsPath = DEFAULT_COMPLETIONS_PATH; + @NestedConfigurationProperty private OpenAiChatOptions options = OpenAiChatOptions.builder() .withModel(DEFAULT_CHAT_MODEL) @@ -55,4 +59,12 @@ public class OpenAiChatProperties extends OpenAiParentProperties { this.enabled = enabled; } + public String getCompletionsPath() { + return completionsPath; + } + + public void setCompletionsPath(String completionsPath) { + this.completionsPath = completionsPath; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java index 5901d4013..008a3c18d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java @@ -27,6 +27,8 @@ public class OpenAiEmbeddingProperties extends OpenAiParentProperties { public static final String DEFAULT_EMBEDDING_MODEL = "text-embedding-ada-002"; + public static final String DEFAULT_EMBEDDINGS_PATH = "/v1/embeddings"; + /** * Enable OpenAI embedding model. */ @@ -34,6 +36,8 @@ public class OpenAiEmbeddingProperties extends OpenAiParentProperties { private MetadataMode metadataMode = MetadataMode.EMBED; + private String embeddingsPath = DEFAULT_EMBEDDINGS_PATH; + @NestedConfigurationProperty private OpenAiEmbeddingOptions options = OpenAiEmbeddingOptions.builder() .withModel(DEFAULT_EMBEDDING_MODEL) @@ -63,4 +67,12 @@ public class OpenAiEmbeddingProperties extends OpenAiParentProperties { this.enabled = enabled; } + public String getEmbeddingsPath() { + return embeddingsPath; + } + + public void setEmbeddingsPath(String embeddingsPath) { + this.embeddingsPath = embeddingsPath; + } + }