Make chat and embedding paths configurable

This commit is contained in:
Mark Pollack
2024-07-19 15:11:04 -04:00
parent 60308eab77
commit e2301157d5
4 changed files with 80 additions and 10 deletions

View File

@@ -23,6 +23,7 @@ import java.util.function.Predicate;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.micrometer.common.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@@ -56,6 +57,10 @@ public class OpenAiApi {
public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.TEXT_EMBEDDING_ADA_002.getValue();
private static final Predicate<String> SSE_DONE_PREDICATE = "[DONE]"::equals;
private final String completionsPath;
private final String embeddingsPath;
private final RestClient restClient;
private final WebClient webClient;
@@ -99,7 +104,24 @@ public class OpenAiApi {
* @param responseErrorHandler Response error handler.
*/
public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
this(baseUrl, openAiToken, "/v1/chat/completions", "/v1/embeddings",
restClientBuilder, webClientBuilder, responseErrorHandler);
}
/**
* Create a new chat completion api.
*
* @param baseUrl api base URL.
* @param openAiToken OpenAI apiKey.
* @param restClientBuilder RestClient builder.
* @param responseErrorHandler Response error handler.
*/
public OpenAiApi(String baseUrl, String openAiToken, String completionsPath, String embeddingsPath,
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
Assert.hasText(completionsPath, "Completions Path must not be null");
Assert.hasText(embeddingsPath, "Embeddings Path must not be null");
this.completionsPath = completionsPath;
this.embeddingsPath = embeddingsPath;
this.restClient = restClientBuilder
.baseUrl(baseUrl)
.defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken))
@@ -812,7 +834,7 @@ public class OpenAiApi {
Assert.isTrue(!chatRequest.stream(), "Request must set the steam property to false.");
return this.restClient.post()
.uri("/v1/chat/completions")
.uri(this.completionsPath)
.body(chatRequest)
.retrieve()
.toEntity(ChatCompletion.class);
@@ -834,7 +856,7 @@ public class OpenAiApi {
AtomicBoolean isInsideTool = new AtomicBoolean(false);
return this.webClient.post()
.uri("/v1/chat/completions")
.uri(this.completionsPath)
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
.retrieve()
.bodyToFlux(String.class)
@@ -1022,7 +1044,7 @@ public class OpenAiApi {
}
return this.restClient.post()
.uri("/v1/embeddings")
.uri(this.embeddingsPath)
.body(embeddingRequest)
.retrieve()
.toEntity(new ParameterizedTypeReference<>() {

View File

@@ -17,6 +17,7 @@ package org.springframework.ai.autoconfigure.openai;
import java.util.List;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
@@ -70,8 +71,7 @@ public class OpenAiAutoConfiguration {
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate,
ResponseErrorHandler responseErrorHandler) {
var openAiApi = openAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(),
chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, webClientBuilder,
var openAiApi = openAiApi(chatProperties, commonProperties, restClientBuilder, webClientBuilder,
responseErrorHandler, "chat");
if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
@@ -90,17 +90,39 @@ public class OpenAiAutoConfiguration {
WebClient.Builder webClientBuilder, RetryTemplate retryTemplate,
ResponseErrorHandler responseErrorHandler) {
var openAiApi = openAiApi(embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(),
embeddingProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, webClientBuilder,
var openAiApi = openAiApi(embeddingProperties, commonProperties, restClientBuilder, webClientBuilder,
responseErrorHandler, "embedding");
return new OpenAiEmbeddingModel(openAiApi, embeddingProperties.getMetadataMode(),
embeddingProperties.getOptions(), retryTemplate);
}
private OpenAiApi openAiApi(String baseUrl, String commonBaseUrl, String apiKey, String commonApiKey,
private OpenAiApi openAiApi(OpenAiChatProperties chatProperties, OpenAiConnectionProperties commonProperties,
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
ResponseErrorHandler responseErrorHandler, String modelType) {
ResolvedBaseUrlAndApiKey result = getResolvedBaseUrlAndApiKey(chatProperties.getBaseUrl(),
chatProperties.getApiKey(), commonProperties, modelType);
return new OpenAiApi(result.resolvedBaseUrl(), result.resolvedApiKey(), chatProperties.getCompletionsPath(),
OpenAiEmbeddingProperties.DEFAULT_EMBEDDINGS_PATH, restClientBuilder, webClientBuilder,
responseErrorHandler);
}
private OpenAiApi openAiApi(OpenAiEmbeddingProperties embeddingProperties,
OpenAiConnectionProperties commonProperties, RestClient.Builder restClientBuilder,
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler, String modelType) {
ResolvedBaseUrlAndApiKey result = getResolvedBaseUrlAndApiKey(embeddingProperties.getBaseUrl(),
embeddingProperties.getApiKey(), commonProperties, modelType);
return new OpenAiApi(result.resolvedBaseUrl(), result.resolvedApiKey(),
OpenAiChatProperties.DEFAULT_COMPLETIONS_PATH, embeddingProperties.getEmbeddingsPath(),
restClientBuilder, webClientBuilder, responseErrorHandler);
}
private static @NotNull ResolvedBaseUrlAndApiKey getResolvedBaseUrlAndApiKey(String baseUrl, String apiKey,
OpenAiConnectionProperties commonProperties, String modelType) {
var commonBaseUrl = commonProperties.getBaseUrl();
var commonApiKey = commonProperties.getApiKey();
String resolvedBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl;
Assert.hasText(resolvedBaseUrl,
@@ -111,9 +133,11 @@ public class OpenAiAutoConfiguration {
Assert.hasText(resolvedApiKey,
"OpenAI API key must be set. Use the connection property: spring.ai.openai.api-key or spring.ai.openai."
+ modelType + ".api-key property.");
ResolvedBaseUrlAndApiKey result = new ResolvedBaseUrlAndApiKey(resolvedBaseUrl, resolvedApiKey);
return result;
}
return new OpenAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder, webClientBuilder,
responseErrorHandler);
private record ResolvedBaseUrlAndApiKey(String resolvedBaseUrl, String resolvedApiKey) {
}
@Bean

View File

@@ -28,11 +28,15 @@ public class OpenAiChatProperties extends OpenAiParentProperties {
private static final Double DEFAULT_TEMPERATURE = 0.7;
public static final String DEFAULT_COMPLETIONS_PATH = "/v1/chat/completions";
/**
* Enable OpenAI chat model.
*/
private boolean enabled = true;
private String completionsPath = DEFAULT_COMPLETIONS_PATH;
@NestedConfigurationProperty
private OpenAiChatOptions options = OpenAiChatOptions.builder()
.withModel(DEFAULT_CHAT_MODEL)
@@ -55,4 +59,12 @@ public class OpenAiChatProperties extends OpenAiParentProperties {
this.enabled = enabled;
}
public String getCompletionsPath() {
return completionsPath;
}
public void setCompletionsPath(String completionsPath) {
this.completionsPath = completionsPath;
}
}

View File

@@ -27,6 +27,8 @@ public class OpenAiEmbeddingProperties extends OpenAiParentProperties {
public static final String DEFAULT_EMBEDDING_MODEL = "text-embedding-ada-002";
public static final String DEFAULT_EMBEDDINGS_PATH = "/v1/embeddings";
/**
* Enable OpenAI embedding model.
*/
@@ -34,6 +36,8 @@ public class OpenAiEmbeddingProperties extends OpenAiParentProperties {
private MetadataMode metadataMode = MetadataMode.EMBED;
private String embeddingsPath = DEFAULT_EMBEDDINGS_PATH;
@NestedConfigurationProperty
private OpenAiEmbeddingOptions options = OpenAiEmbeddingOptions.builder()
.withModel(DEFAULT_EMBEDDING_MODEL)
@@ -63,4 +67,12 @@ public class OpenAiEmbeddingProperties extends OpenAiParentProperties {
this.enabled = enabled;
}
public String getEmbeddingsPath() {
return embeddingsPath;
}
public void setEmbeddingsPath(String embeddingsPath) {
this.embeddingsPath = embeddingsPath;
}
}