Make chat and embedding paths configurable
This commit is contained in:
@@ -23,6 +23,7 @@ import java.util.function.Predicate;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude.Include;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import io.micrometer.common.util.StringUtils;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
@@ -56,6 +57,10 @@ public class OpenAiApi {
|
||||
public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.TEXT_EMBEDDING_ADA_002.getValue();
|
||||
private static final Predicate<String> SSE_DONE_PREDICATE = "[DONE]"::equals;
|
||||
|
||||
private final String completionsPath;
|
||||
|
||||
private final String embeddingsPath;
|
||||
|
||||
private final RestClient restClient;
|
||||
|
||||
private final WebClient webClient;
|
||||
@@ -99,7 +104,24 @@ public class OpenAiApi {
|
||||
* @param responseErrorHandler Response error handler.
|
||||
*/
|
||||
public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
|
||||
this(baseUrl, openAiToken, "/v1/chat/completions", "/v1/embeddings",
|
||||
restClientBuilder, webClientBuilder, responseErrorHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new chat completion api.
|
||||
*
|
||||
* @param baseUrl api base URL.
|
||||
* @param openAiToken OpenAI apiKey.
|
||||
* @param restClientBuilder RestClient builder.
|
||||
* @param responseErrorHandler Response error handler.
|
||||
*/
|
||||
public OpenAiApi(String baseUrl, String openAiToken, String completionsPath, String embeddingsPath,
|
||||
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
|
||||
Assert.hasText(completionsPath, "Completions Path must not be null");
|
||||
Assert.hasText(embeddingsPath, "Embeddings Path must not be null");
|
||||
this.completionsPath = completionsPath;
|
||||
this.embeddingsPath = embeddingsPath;
|
||||
this.restClient = restClientBuilder
|
||||
.baseUrl(baseUrl)
|
||||
.defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken))
|
||||
@@ -812,7 +834,7 @@ public class OpenAiApi {
|
||||
Assert.isTrue(!chatRequest.stream(), "Request must set the steam property to false.");
|
||||
|
||||
return this.restClient.post()
|
||||
.uri("/v1/chat/completions")
|
||||
.uri(this.completionsPath)
|
||||
.body(chatRequest)
|
||||
.retrieve()
|
||||
.toEntity(ChatCompletion.class);
|
||||
@@ -834,7 +856,7 @@ public class OpenAiApi {
|
||||
AtomicBoolean isInsideTool = new AtomicBoolean(false);
|
||||
|
||||
return this.webClient.post()
|
||||
.uri("/v1/chat/completions")
|
||||
.uri(this.completionsPath)
|
||||
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
|
||||
.retrieve()
|
||||
.bodyToFlux(String.class)
|
||||
@@ -1022,7 +1044,7 @@ public class OpenAiApi {
|
||||
}
|
||||
|
||||
return this.restClient.post()
|
||||
.uri("/v1/embeddings")
|
||||
.uri(this.embeddingsPath)
|
||||
.body(embeddingRequest)
|
||||
.retrieve()
|
||||
.toEntity(new ParameterizedTypeReference<>() {
|
||||
|
||||
@@ -17,6 +17,7 @@ package org.springframework.ai.autoconfigure.openai;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallbackContext;
|
||||
@@ -70,8 +71,7 @@ public class OpenAiAutoConfiguration {
|
||||
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate,
|
||||
ResponseErrorHandler responseErrorHandler) {
|
||||
|
||||
var openAiApi = openAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(),
|
||||
chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, webClientBuilder,
|
||||
var openAiApi = openAiApi(chatProperties, commonProperties, restClientBuilder, webClientBuilder,
|
||||
responseErrorHandler, "chat");
|
||||
|
||||
if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
|
||||
@@ -90,17 +90,39 @@ public class OpenAiAutoConfiguration {
|
||||
WebClient.Builder webClientBuilder, RetryTemplate retryTemplate,
|
||||
ResponseErrorHandler responseErrorHandler) {
|
||||
|
||||
var openAiApi = openAiApi(embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(),
|
||||
embeddingProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, webClientBuilder,
|
||||
var openAiApi = openAiApi(embeddingProperties, commonProperties, restClientBuilder, webClientBuilder,
|
||||
responseErrorHandler, "embedding");
|
||||
|
||||
return new OpenAiEmbeddingModel(openAiApi, embeddingProperties.getMetadataMode(),
|
||||
embeddingProperties.getOptions(), retryTemplate);
|
||||
}
|
||||
|
||||
private OpenAiApi openAiApi(String baseUrl, String commonBaseUrl, String apiKey, String commonApiKey,
|
||||
private OpenAiApi openAiApi(OpenAiChatProperties chatProperties, OpenAiConnectionProperties commonProperties,
|
||||
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
|
||||
ResponseErrorHandler responseErrorHandler, String modelType) {
|
||||
ResolvedBaseUrlAndApiKey result = getResolvedBaseUrlAndApiKey(chatProperties.getBaseUrl(),
|
||||
chatProperties.getApiKey(), commonProperties, modelType);
|
||||
|
||||
return new OpenAiApi(result.resolvedBaseUrl(), result.resolvedApiKey(), chatProperties.getCompletionsPath(),
|
||||
OpenAiEmbeddingProperties.DEFAULT_EMBEDDINGS_PATH, restClientBuilder, webClientBuilder,
|
||||
responseErrorHandler);
|
||||
}
|
||||
|
||||
private OpenAiApi openAiApi(OpenAiEmbeddingProperties embeddingProperties,
|
||||
OpenAiConnectionProperties commonProperties, RestClient.Builder restClientBuilder,
|
||||
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler, String modelType) {
|
||||
ResolvedBaseUrlAndApiKey result = getResolvedBaseUrlAndApiKey(embeddingProperties.getBaseUrl(),
|
||||
embeddingProperties.getApiKey(), commonProperties, modelType);
|
||||
|
||||
return new OpenAiApi(result.resolvedBaseUrl(), result.resolvedApiKey(),
|
||||
OpenAiChatProperties.DEFAULT_COMPLETIONS_PATH, embeddingProperties.getEmbeddingsPath(),
|
||||
restClientBuilder, webClientBuilder, responseErrorHandler);
|
||||
}
|
||||
|
||||
private static @NotNull ResolvedBaseUrlAndApiKey getResolvedBaseUrlAndApiKey(String baseUrl, String apiKey,
|
||||
OpenAiConnectionProperties commonProperties, String modelType) {
|
||||
var commonBaseUrl = commonProperties.getBaseUrl();
|
||||
var commonApiKey = commonProperties.getApiKey();
|
||||
|
||||
String resolvedBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl;
|
||||
Assert.hasText(resolvedBaseUrl,
|
||||
@@ -111,9 +133,11 @@ public class OpenAiAutoConfiguration {
|
||||
Assert.hasText(resolvedApiKey,
|
||||
"OpenAI API key must be set. Use the connection property: spring.ai.openai.api-key or spring.ai.openai."
|
||||
+ modelType + ".api-key property.");
|
||||
ResolvedBaseUrlAndApiKey result = new ResolvedBaseUrlAndApiKey(resolvedBaseUrl, resolvedApiKey);
|
||||
return result;
|
||||
}
|
||||
|
||||
return new OpenAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder, webClientBuilder,
|
||||
responseErrorHandler);
|
||||
private record ResolvedBaseUrlAndApiKey(String resolvedBaseUrl, String resolvedApiKey) {
|
||||
}
|
||||
|
||||
@Bean
|
||||
|
||||
@@ -28,11 +28,15 @@ public class OpenAiChatProperties extends OpenAiParentProperties {
|
||||
|
||||
private static final Double DEFAULT_TEMPERATURE = 0.7;
|
||||
|
||||
public static final String DEFAULT_COMPLETIONS_PATH = "/v1/chat/completions";
|
||||
|
||||
/**
|
||||
* Enable OpenAI chat model.
|
||||
*/
|
||||
private boolean enabled = true;
|
||||
|
||||
private String completionsPath = DEFAULT_COMPLETIONS_PATH;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
private OpenAiChatOptions options = OpenAiChatOptions.builder()
|
||||
.withModel(DEFAULT_CHAT_MODEL)
|
||||
@@ -55,4 +59,12 @@ public class OpenAiChatProperties extends OpenAiParentProperties {
|
||||
this.enabled = enabled;
|
||||
}
|
||||
|
||||
public String getCompletionsPath() {
|
||||
return completionsPath;
|
||||
}
|
||||
|
||||
public void setCompletionsPath(String completionsPath) {
|
||||
this.completionsPath = completionsPath;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -27,6 +27,8 @@ public class OpenAiEmbeddingProperties extends OpenAiParentProperties {
|
||||
|
||||
public static final String DEFAULT_EMBEDDING_MODEL = "text-embedding-ada-002";
|
||||
|
||||
public static final String DEFAULT_EMBEDDINGS_PATH = "/v1/embeddings";
|
||||
|
||||
/**
|
||||
* Enable OpenAI embedding model.
|
||||
*/
|
||||
@@ -34,6 +36,8 @@ public class OpenAiEmbeddingProperties extends OpenAiParentProperties {
|
||||
|
||||
private MetadataMode metadataMode = MetadataMode.EMBED;
|
||||
|
||||
private String embeddingsPath = DEFAULT_EMBEDDINGS_PATH;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
private OpenAiEmbeddingOptions options = OpenAiEmbeddingOptions.builder()
|
||||
.withModel(DEFAULT_EMBEDDING_MODEL)
|
||||
@@ -63,4 +67,12 @@ public class OpenAiEmbeddingProperties extends OpenAiParentProperties {
|
||||
this.enabled = enabled;
|
||||
}
|
||||
|
||||
public String getEmbeddingsPath() {
|
||||
return embeddingsPath;
|
||||
}
|
||||
|
||||
public void setEmbeddingsPath(String embeddingsPath) {
|
||||
this.embeddingsPath = embeddingsPath;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user