migrate remaining moonshot modules to community repo
This commit is contained in:
@@ -1 +0,0 @@
|
||||
[Moonshot Chat Documentation](https://docs.spring.io/spring-ai/reference/api/chat/moonshot-chat.html)
|
||||
@@ -1,85 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!--
|
||||
~ Copyright 2023-2024 the original author or authors.
|
||||
~
|
||||
~ Licensed under the Apache License, Version 2.0 (the "License");
|
||||
~ you may not use this file except in compliance with the License.
|
||||
~ You may obtain a copy of the License at
|
||||
~
|
||||
~ https://www.apache.org/licenses/LICENSE-2.0
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS,
|
||||
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
~ See the License for the specific language governing permissions and
|
||||
~ limitations under the License.
|
||||
-->
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-parent</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<relativePath>../../pom.xml</relativePath>
|
||||
</parent>
|
||||
<artifactId>spring-ai-moonshot</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
<name>Spring AI Moonshot</name>
|
||||
<description>Moonshot support</description>
|
||||
<url>https://github.com/spring-projects/spring-ai</url>
|
||||
|
||||
<scm>
|
||||
<url>https://github.com/spring-projects/spring-ai</url>
|
||||
<connection>git://github.com/spring-projects/spring-ai.git</connection>
|
||||
<developerConnection>git@github.com:spring-projects/spring-ai.git</developerConnection>
|
||||
</scm>
|
||||
|
||||
|
||||
<properties>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
|
||||
<!-- production dependencies -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-client-chat</artifactId>
|
||||
<version>${project.parent.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-retry</artifactId>
|
||||
<version>${project.parent.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- Spring Framework -->
|
||||
<dependency>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-context-support</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- test dependencies -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-test</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>io.micrometer</groupId>
|
||||
<artifactId>micrometer-observation-test</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
@@ -1,483 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import io.micrometer.observation.Observation;
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.messages.ToolResponseMessage;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.DefaultUsage;
|
||||
import org.springframework.ai.chat.metadata.EmptyUsage;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.chat.metadata.UsageUtils;
|
||||
import org.springframework.ai.chat.model.AbstractToolCallSupport;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.model.MessageAggregator;
|
||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||
import org.springframework.ai.chat.observation.ChatModelObservationContext;
|
||||
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
|
||||
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
|
||||
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallbackResolver;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion.Choice;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionChunk;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionFinishReason;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.ChatCompletionFunction;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.ToolCall;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionRequest;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.FunctionTool;
|
||||
import org.springframework.ai.moonshot.api.MoonshotConstants;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* MoonshotChatModel is a {@link ChatModel} implementation that uses the Moonshot
|
||||
*
|
||||
* @author Geng Rong
|
||||
* @author Alexandros Pappas
|
||||
* @author Ilayaperumal Gopinathan
|
||||
*/
|
||||
public class MoonshotChatModel extends AbstractToolCallSupport implements ChatModel, StreamingChatModel {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(MoonshotChatModel.class);
|
||||
|
||||
private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
|
||||
|
||||
/**
|
||||
* The default options used for the chat completion requests.
|
||||
*/
|
||||
private final MoonshotChatOptions defaultOptions;
|
||||
|
||||
/**
|
||||
* Low-level access to the Moonshot API.
|
||||
*/
|
||||
private final MoonshotApi moonshotApi;
|
||||
|
||||
private final RetryTemplate retryTemplate;
|
||||
|
||||
/**
|
||||
* Observation registry used for instrumentation.
|
||||
*/
|
||||
private final ObservationRegistry observationRegistry;
|
||||
|
||||
/**
|
||||
* Conventions to use for generating observations.
|
||||
*/
|
||||
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the MoonshotChatModel.
|
||||
* @param moonshotApi The Moonshot instance to be used for interacting with the
|
||||
* Moonshot Chat API.
|
||||
*/
|
||||
public MoonshotChatModel(MoonshotApi moonshotApi) {
|
||||
this(moonshotApi, MoonshotChatOptions.builder().model(MoonshotApi.DEFAULT_CHAT_MODEL).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the MoonshotChatModel.
|
||||
* @param moonshotApi The Moonshot instance to be used for interacting with the
|
||||
* Moonshot Chat API.
|
||||
* @param options The MoonshotChatOptions to configure the chat client.
|
||||
*/
|
||||
public MoonshotChatModel(MoonshotApi moonshotApi, MoonshotChatOptions options) {
|
||||
this(moonshotApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the MoonshotChatModel.
|
||||
* @param moonshotApi The Moonshot instance to be used for interacting with the
|
||||
* Moonshot Chat API.
|
||||
* @param options The MoonshotChatOptions to configure the chat client.
|
||||
* @param functionCallbackResolver The function callback resolver to resolve the
|
||||
* function by its name.
|
||||
* @param retryTemplate The retry template.
|
||||
*/
|
||||
public MoonshotChatModel(MoonshotApi moonshotApi, MoonshotChatOptions options,
|
||||
FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate) {
|
||||
this(moonshotApi, options, functionCallbackResolver, List.of(), retryTemplate, ObservationRegistry.NOOP);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the MoonshotChatModel.
|
||||
* @param moonshotApi The Moonshot instance to be used for interacting with the
|
||||
* Moonshot Chat API.
|
||||
* @param options The MoonshotChatOptions to configure the chat client.
|
||||
* @param functionCallbackResolver resolves the function by its name.
|
||||
* @param toolFunctionCallbacks The tool function callbacks.
|
||||
* @param retryTemplate The retry template.
|
||||
* @param observationRegistry The ObservationRegistry used for instrumentation.
|
||||
*/
|
||||
public MoonshotChatModel(MoonshotApi moonshotApi, MoonshotChatOptions options,
|
||||
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
|
||||
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
|
||||
super(functionCallbackResolver, options, toolFunctionCallbacks);
|
||||
Assert.notNull(moonshotApi, "MoonshotApi must not be null");
|
||||
Assert.notNull(options, "Options must not be null");
|
||||
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
|
||||
Assert.isTrue(CollectionUtils.isEmpty(options.getFunctionCallbacks()),
|
||||
"The default function callbacks must be set via the toolFunctionCallbacks constructor parameter");
|
||||
Assert.notNull(observationRegistry, "ObservationRegistry must not be null");
|
||||
this.moonshotApi = moonshotApi;
|
||||
this.defaultOptions = options;
|
||||
this.retryTemplate = retryTemplate;
|
||||
this.observationRegistry = observationRegistry;
|
||||
}
|
||||
|
||||
private static Generation buildGeneration(Choice choice, Map<String, Object> metadata) {
|
||||
List<AssistantMessage.ToolCall> toolCalls = choice.message().toolCalls() == null ? List.of()
|
||||
: choice.message()
|
||||
.toolCalls()
|
||||
.stream()
|
||||
.map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function",
|
||||
toolCall.function().name(), toolCall.function().arguments()))
|
||||
.toList();
|
||||
|
||||
var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
|
||||
String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : "");
|
||||
var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
|
||||
return new Generation(assistantMessage, generationMetadata);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse call(Prompt prompt) {
|
||||
return this.internalCall(prompt, null);
|
||||
}
|
||||
|
||||
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
|
||||
ChatCompletionRequest request = createRequest(prompt, false);
|
||||
|
||||
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
|
||||
.prompt(prompt)
|
||||
.provider(MoonshotConstants.PROVIDER_NAME)
|
||||
.requestOptions(buildRequestOptions(request))
|
||||
.build();
|
||||
|
||||
ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
|
||||
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
|
||||
this.observationRegistry)
|
||||
.observe(() -> {
|
||||
ResponseEntity<ChatCompletion> completionEntity = this.retryTemplate
|
||||
.execute(ctx -> this.moonshotApi.chatCompletionEntity(request));
|
||||
|
||||
var chatCompletion = completionEntity.getBody();
|
||||
|
||||
if (chatCompletion == null) {
|
||||
logger.warn("No chat completion returned for prompt: {}", prompt);
|
||||
return new ChatResponse(List.of());
|
||||
}
|
||||
|
||||
List<Choice> choices = chatCompletion.choices();
|
||||
if (choices == null) {
|
||||
logger.warn("No choices returned for prompt: {}", prompt);
|
||||
return new ChatResponse(List.of());
|
||||
}
|
||||
|
||||
List<Generation> generations = choices.stream().map(choice -> {
|
||||
// @formatter:off
|
||||
Map<String, Object> metadata = Map.of(
|
||||
"id", chatCompletion.id(),
|
||||
"role", choice.message().role() != null ? choice.message().role().name() : "",
|
||||
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""
|
||||
);
|
||||
// @formatter:on
|
||||
return buildGeneration(choice, metadata);
|
||||
}).toList();
|
||||
MoonshotApi.Usage usage = completionEntity.getBody().usage();
|
||||
Usage currentUsage = (usage != null) ? getDefaultUsage(usage) : new EmptyUsage();
|
||||
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
|
||||
ChatResponse chatResponse = new ChatResponse(generations,
|
||||
from(completionEntity.getBody(), cumulativeUsage));
|
||||
|
||||
observationContext.setResponse(chatResponse);
|
||||
|
||||
return chatResponse;
|
||||
});
|
||||
|
||||
if (!isProxyToolCalls(prompt, this.defaultOptions)
|
||||
&& isToolCall(response, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
|
||||
MoonshotApi.ChatCompletionFinishReason.STOP.name()))) {
|
||||
var toolCallConversation = handleToolCalls(prompt, response);
|
||||
// Recursively call the call method with the tool call message
|
||||
// conversation that contains the call responses.
|
||||
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
|
||||
}
|
||||
return response;
|
||||
}
|
||||
|
||||
private DefaultUsage getDefaultUsage(MoonshotApi.Usage usage) {
|
||||
return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return this.defaultOptions.copy();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> stream(Prompt prompt) {
|
||||
return this.internalStream(prompt, null);
|
||||
}
|
||||
|
||||
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
|
||||
return Flux.deferContextual(contextView -> {
|
||||
ChatCompletionRequest request = createRequest(prompt, true);
|
||||
|
||||
Flux<ChatCompletionChunk> completionChunks = this.retryTemplate
|
||||
.execute(ctx -> this.moonshotApi.chatCompletionStream(request));
|
||||
|
||||
// For chunked responses, only the first chunk contains the choice role.
|
||||
// The rest of the chunks with same ID share the same role.
|
||||
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
|
||||
|
||||
final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
|
||||
.prompt(prompt)
|
||||
.provider(MoonshotConstants.PROVIDER_NAME)
|
||||
.requestOptions(buildRequestOptions(request))
|
||||
.build();
|
||||
|
||||
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
|
||||
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
|
||||
this.observationRegistry);
|
||||
|
||||
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
|
||||
|
||||
// Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
|
||||
// the function call handling logic.
|
||||
Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
|
||||
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
|
||||
try {
|
||||
String id = chatCompletion2.id();
|
||||
|
||||
List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
|
||||
if (choice.message().role() != null) {
|
||||
roleMap.putIfAbsent(id, choice.message().role().name());
|
||||
}
|
||||
|
||||
// @formatter:off
|
||||
Map<String, Object> metadata = Map.of(
|
||||
"id", chatCompletion2.id(),
|
||||
"role", roleMap.getOrDefault(id, ""),
|
||||
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""
|
||||
);
|
||||
// @formatter:on
|
||||
return buildGeneration(choice, metadata);
|
||||
}).toList();
|
||||
MoonshotApi.Usage usage = chatCompletion2.usage();
|
||||
Usage currentUsage = (usage != null) ? getDefaultUsage(usage) : new EmptyUsage();
|
||||
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
|
||||
|
||||
return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage));
|
||||
}
|
||||
catch (Exception e) {
|
||||
logger.error("Error processing chat completion", e);
|
||||
return new ChatResponse(List.of());
|
||||
}
|
||||
|
||||
}));
|
||||
|
||||
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
|
||||
if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response,
|
||||
Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {
|
||||
// FIXME: bounded elastic needs to be used since tool calling
|
||||
// is currently only synchronous
|
||||
return Flux.defer(() -> {
|
||||
var toolCallConversation = handleToolCalls(prompt, response);
|
||||
// Recursively call the stream method with the tool call message
|
||||
// conversation that contains the call responses.
|
||||
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
|
||||
}).subscribeOn(Schedulers.boundedElastic());
|
||||
}
|
||||
return Flux.just(response);
|
||||
})
|
||||
.doOnError(observation::error)
|
||||
.doFinally(signalType -> observation.stop())
|
||||
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
|
||||
|
||||
return new MessageAggregator().aggregate(flux, observationContext::setResponse);
|
||||
});
|
||||
}
|
||||
|
||||
private ChatResponseMetadata from(ChatCompletion result) {
|
||||
Assert.notNull(result, "Moonshot ChatCompletionResult must not be null");
|
||||
return ChatResponseMetadata.builder()
|
||||
.id(result.id() != null ? result.id() : "")
|
||||
.usage(result.usage() != null ? getDefaultUsage(result.usage()) : new EmptyUsage())
|
||||
.model(result.model() != null ? result.model() : "")
|
||||
.keyValue("created", result.created() != null ? result.created() : 0L)
|
||||
.build();
|
||||
}
|
||||
|
||||
private ChatResponseMetadata from(ChatCompletion result, Usage usage) {
|
||||
Assert.notNull(result, "Moonshot ChatCompletionResult must not be null");
|
||||
return ChatResponseMetadata.builder()
|
||||
.id(result.id() != null ? result.id() : "")
|
||||
.usage(usage)
|
||||
.model(result.model() != null ? result.model() : "")
|
||||
.keyValue("created", result.created() != null ? result.created() : 0L)
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null.
|
||||
* @param chunk the ChatCompletionChunk to convert
|
||||
* @return the ChatCompletion
|
||||
*/
|
||||
private ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) {
|
||||
List<ChatCompletion.Choice> choices = chunk.choices().stream().map(cc -> {
|
||||
ChatCompletionMessage delta = cc.delta();
|
||||
if (delta == null) {
|
||||
delta = new ChatCompletionMessage("", ChatCompletionMessage.Role.ASSISTANT);
|
||||
}
|
||||
return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason(), cc.usage());
|
||||
}).toList();
|
||||
// Get the usage from the latest choice
|
||||
MoonshotApi.Usage usage = choices.get(choices.size() - 1).usage();
|
||||
return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, usage);
|
||||
}
|
||||
|
||||
/**
|
||||
* Accessible for testing.
|
||||
*/
|
||||
public MoonshotApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
|
||||
|
||||
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> {
|
||||
if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) {
|
||||
Object content = message.getText();
|
||||
return List.of(new ChatCompletionMessage(content,
|
||||
ChatCompletionMessage.Role.valueOf(message.getMessageType().name())));
|
||||
}
|
||||
else if (message.getMessageType() == MessageType.ASSISTANT) {
|
||||
var assistantMessage = (AssistantMessage) message;
|
||||
List<ToolCall> toolCalls = null;
|
||||
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
|
||||
toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
|
||||
var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments());
|
||||
return new ToolCall(toolCall.id(), toolCall.type(), function);
|
||||
}).toList();
|
||||
}
|
||||
return List.of(new ChatCompletionMessage(assistantMessage.getText(),
|
||||
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls));
|
||||
}
|
||||
else if (message.getMessageType() == MessageType.TOOL) {
|
||||
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
|
||||
|
||||
toolMessage.getResponses()
|
||||
.forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"));
|
||||
|
||||
return toolMessage.getResponses()
|
||||
.stream()
|
||||
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
|
||||
tr.id(), null))
|
||||
.toList();
|
||||
}
|
||||
else {
|
||||
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
|
||||
}
|
||||
}).flatMap(List::stream).toList();
|
||||
|
||||
ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);
|
||||
|
||||
Set<String> enabledToolsToUse = new HashSet<>();
|
||||
|
||||
if (prompt.getOptions() != null) {
|
||||
MoonshotChatOptions updatedRuntimeOptions;
|
||||
|
||||
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
|
||||
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
|
||||
FunctionCallingOptions.class, MoonshotChatOptions.class);
|
||||
}
|
||||
else {
|
||||
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
|
||||
MoonshotChatOptions.class);
|
||||
}
|
||||
enabledToolsToUse.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));
|
||||
|
||||
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
|
||||
}
|
||||
|
||||
if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) {
|
||||
enabledToolsToUse.addAll(this.defaultOptions.getFunctions());
|
||||
}
|
||||
|
||||
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
|
||||
|
||||
// Add the enabled functions definitions to the request's tools parameter.
|
||||
if (!CollectionUtils.isEmpty(enabledToolsToUse)) {
|
||||
|
||||
request = ModelOptionsUtils.merge(
|
||||
MoonshotChatOptions.builder().tools(this.getFunctionTools(enabledToolsToUse)).build(), request,
|
||||
ChatCompletionRequest.class);
|
||||
}
|
||||
|
||||
return request;
|
||||
}
|
||||
|
||||
private ChatOptions buildRequestOptions(MoonshotApi.ChatCompletionRequest request) {
|
||||
return ChatOptions.builder()
|
||||
.model(request.model())
|
||||
.frequencyPenalty(request.frequencyPenalty())
|
||||
.maxTokens(request.maxTokens())
|
||||
.presencePenalty(request.presencePenalty())
|
||||
.stopSequences(request.stop())
|
||||
.temperature(request.temperature())
|
||||
.topP(request.topP())
|
||||
.build();
|
||||
}
|
||||
|
||||
private List<FunctionTool> getFunctionTools(Set<String> functionNames) {
|
||||
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
|
||||
var function = new FunctionTool.Function(functionCallback.getDescription(), functionCallback.getName(),
|
||||
functionCallback.getInputTypeSchema());
|
||||
return new FunctionTool(function);
|
||||
}).toList();
|
||||
}
|
||||
|
||||
public void setObservationConvention(ChatModelObservationConvention observationConvention) {
|
||||
this.observationConvention = observationConvention;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,518 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* Options for Moonshot chat completions.
|
||||
*
|
||||
* @author Geng Rong
|
||||
* @author Thomas Vitale
|
||||
* @author Alexandros Pappas
|
||||
*/
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public class MoonshotChatOptions implements FunctionCallingOptions {
|
||||
|
||||
/**
|
||||
* ID of the model to use
|
||||
*/
|
||||
private @JsonProperty("model") String model;
|
||||
|
||||
/**
|
||||
* The maximum number of tokens to generate in the chat completion. The total length
|
||||
* of input tokens and generated tokens is limited by the model's context length.
|
||||
*/
|
||||
private @JsonProperty("max_tokens") Integer maxTokens;
|
||||
|
||||
/**
|
||||
* What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will
|
||||
* make the output more random, while lower values like 0.2 will make it more focused
|
||||
* and deterministic. We generally recommend altering this or top_p but not both.
|
||||
*/
|
||||
private @JsonProperty("temperature") Double temperature;
|
||||
|
||||
/**
|
||||
* An alternative to sampling with temperature, called nucleus sampling, where the
|
||||
* model considers the results of the tokens with top_p probability mass. So 0.1 means
|
||||
* only the tokens comprising the top 10% probability mass are considered. We
|
||||
* generally recommend altering this or temperature but not both.
|
||||
*/
|
||||
private @JsonProperty("top_p") Double topP;
|
||||
|
||||
/**
|
||||
* How many chat completion choices to generate for each input message. Note that you
|
||||
* will be charged based on the number of generated tokens across all the choices.
|
||||
* Keep n as 1 to minimize costs.
|
||||
*/
|
||||
private @JsonProperty("n") Integer n;
|
||||
|
||||
/**
|
||||
* Number between -2.0 and 2.0. Positive values penalize new tokens based on whether
|
||||
* they appear in the text so far, increasing the model's likelihood to talk about new
|
||||
* topics.
|
||||
*/
|
||||
private @JsonProperty("presence_penalty") Double presencePenalty;
|
||||
|
||||
/**
|
||||
* Number between -2.0 and 2.0. Positive values penalize new tokens based on their
|
||||
* existing frequency in the text so far, decreasing the model's likelihood to repeat
|
||||
* the same line verbatim.
|
||||
*/
|
||||
private @JsonProperty("frequency_penalty") Double frequencyPenalty;
|
||||
|
||||
/**
|
||||
* Up to 5 sequences where the API will stop generating further tokens.
|
||||
*/
|
||||
private @JsonProperty("stop") List<String> stop;
|
||||
|
||||
private @JsonProperty("tools") List<MoonshotApi.FunctionTool> tools;
|
||||
|
||||
/**
|
||||
* Controls which (if any) function is called by the model. none means the model will
|
||||
* not call a function and instead generates a message. auto means the model can pick
|
||||
* between generating a message or calling a function. Specifying a particular
|
||||
* function via {"type: "function", "function": {"name": "my_function"}} forces the
|
||||
* model to call that function. none is the default when no functions are present.
|
||||
* auto is the default if functions are present. Use the
|
||||
* {@link MoonshotApi.ChatCompletionRequest.ToolChoiceBuilder} to create a tool choice
|
||||
* object.
|
||||
*/
|
||||
private @JsonProperty("tool_choice") String toolChoice;
|
||||
|
||||
/**
|
||||
* Moonshot Tool Function Callbacks to register with the ChatModel. For Prompt Options
|
||||
* the functionCallbacks are automatically enabled for the duration of the prompt
|
||||
* execution. For Default Options the functionCallbacks are registered but disabled by
|
||||
* default. Use the enableFunctions to set the functions from the registry to be used
|
||||
* by the ChatModel chat completion requests.
|
||||
*/
|
||||
@JsonIgnore
|
||||
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* List of functions, identified by their names, to configure for function calling in
|
||||
* the chat completion requests. Functions with those names must exist in the
|
||||
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
|
||||
* are automatically enabled for the duration of the prompt execution.
|
||||
*
|
||||
* Note that function enabled with the default options are enabled for all chat
|
||||
* completion requests. This could impact the token count and the billing. If the
|
||||
* functions is set in a prompt options, then the enabled functions are only active
|
||||
* for the duration of this prompt execution.
|
||||
*/
|
||||
@JsonIgnore
|
||||
private Set<String> functions = new HashSet<>();
|
||||
|
||||
/**
|
||||
* A unique identifier representing your end-user, which can help Moonshot to monitor
|
||||
* and detect abuse.
|
||||
*/
|
||||
private @JsonProperty("user") String user;
|
||||
|
||||
@JsonIgnore
|
||||
private Boolean proxyToolCalls;
|
||||
|
||||
@JsonIgnore
|
||||
private Map<String, Object> toolContext;
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionCallback> getFunctionCallbacks() {
|
||||
return this.functionCallbacks;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
|
||||
this.functionCallbacks = functionCallbacks;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getFunctions() {
|
||||
return this.functions;
|
||||
}
|
||||
|
||||
public void setFunctions(Set<String> functionNames) {
|
||||
this.functions = functionNames;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModel() {
|
||||
return this.model;
|
||||
}
|
||||
|
||||
public void setModel(String model) {
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double getFrequencyPenalty() {
|
||||
return this.frequencyPenalty;
|
||||
}
|
||||
|
||||
public void setFrequencyPenalty(Double frequencyPenalty) {
|
||||
this.frequencyPenalty = frequencyPenalty;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getMaxTokens() {
|
||||
return this.maxTokens;
|
||||
}
|
||||
|
||||
public void setMaxTokens(Integer maxTokens) {
|
||||
this.maxTokens = maxTokens;
|
||||
}
|
||||
|
||||
public Integer getN() {
|
||||
return this.n;
|
||||
}
|
||||
|
||||
public void setN(Integer n) {
|
||||
this.n = n;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double getPresencePenalty() {
|
||||
return this.presencePenalty;
|
||||
}
|
||||
|
||||
public void setPresencePenalty(Double presencePenalty) {
|
||||
this.presencePenalty = presencePenalty;
|
||||
}
|
||||
|
||||
@Override
|
||||
@JsonIgnore
|
||||
public List<String> getStopSequences() {
|
||||
return getStop();
|
||||
}
|
||||
|
||||
@JsonIgnore
|
||||
public void setStopSequences(List<String> stopSequences) {
|
||||
setStop(stopSequences);
|
||||
}
|
||||
|
||||
public List<String> getStop() {
|
||||
return this.stop;
|
||||
}
|
||||
|
||||
public void setStop(List<String> stop) {
|
||||
this.stop = stop;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double getTemperature() {
|
||||
return this.temperature;
|
||||
}
|
||||
|
||||
public void setTemperature(Double temperature) {
|
||||
this.temperature = temperature;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double getTopP() {
|
||||
return this.topP;
|
||||
}
|
||||
|
||||
public void setTopP(Double topP) {
|
||||
this.topP = topP;
|
||||
}
|
||||
|
||||
public String getUser() {
|
||||
return this.user;
|
||||
}
|
||||
|
||||
public void setUser(String user) {
|
||||
this.user = user;
|
||||
}
|
||||
|
||||
@Override
|
||||
@JsonIgnore
|
||||
public Integer getTopK() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean getProxyToolCalls() {
|
||||
return this.proxyToolCalls;
|
||||
}
|
||||
|
||||
public void setProxyToolCalls(Boolean proxyToolCalls) {
|
||||
this.proxyToolCalls = proxyToolCalls;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> getToolContext() {
|
||||
return this.toolContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setToolContext(Map<String, Object> toolContext) {
|
||||
this.toolContext = toolContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MoonshotChatOptions copy() {
|
||||
return builder().model(this.model)
|
||||
.maxTokens(this.maxTokens)
|
||||
.temperature(this.temperature)
|
||||
.topP(this.topP)
|
||||
.N(this.n)
|
||||
.presencePenalty(this.presencePenalty)
|
||||
.frequencyPenalty(this.frequencyPenalty)
|
||||
.stop(this.stop)
|
||||
.user(this.user)
|
||||
.tools(this.tools)
|
||||
.toolChoice(this.toolChoice)
|
||||
.functionCallbacks(this.functionCallbacks)
|
||||
.functions(this.functions)
|
||||
.proxyToolCalls(this.proxyToolCalls)
|
||||
.toolContext(this.toolContext)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
final int prime = 31;
|
||||
int result = 1;
|
||||
result = prime * result + ((this.model == null) ? 0 : this.model.hashCode());
|
||||
result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode());
|
||||
result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode());
|
||||
result = prime * result + ((this.n == null) ? 0 : this.n.hashCode());
|
||||
result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode());
|
||||
result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode());
|
||||
result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode());
|
||||
result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode());
|
||||
result = prime * result + ((this.user == null) ? 0 : this.user.hashCode());
|
||||
result = prime * result + ((this.proxyToolCalls == null) ? 0 : this.proxyToolCalls.hashCode());
|
||||
result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode());
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
if (obj == null) {
|
||||
return false;
|
||||
}
|
||||
if (getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
MoonshotChatOptions other = (MoonshotChatOptions) obj;
|
||||
if (this.model == null) {
|
||||
if (other.model != null) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (!this.model.equals(other.model)) {
|
||||
return false;
|
||||
}
|
||||
if (this.frequencyPenalty == null) {
|
||||
if (other.frequencyPenalty != null) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) {
|
||||
return false;
|
||||
}
|
||||
if (this.maxTokens == null) {
|
||||
if (other.maxTokens != null) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (!this.maxTokens.equals(other.maxTokens)) {
|
||||
return false;
|
||||
}
|
||||
if (this.n == null) {
|
||||
if (other.n != null) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (!this.n.equals(other.n)) {
|
||||
return false;
|
||||
}
|
||||
if (this.presencePenalty == null) {
|
||||
if (other.presencePenalty != null) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (!this.presencePenalty.equals(other.presencePenalty)) {
|
||||
return false;
|
||||
}
|
||||
if (this.stop == null) {
|
||||
if (other.stop != null) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (!this.stop.equals(other.stop)) {
|
||||
return false;
|
||||
}
|
||||
if (this.temperature == null) {
|
||||
if (other.temperature != null) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (!this.temperature.equals(other.temperature)) {
|
||||
return false;
|
||||
}
|
||||
if (this.topP == null) {
|
||||
if (other.topP != null) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (!this.topP.equals(other.topP)) {
|
||||
return false;
|
||||
}
|
||||
if (this.user == null) {
|
||||
return other.user == null;
|
||||
}
|
||||
else if (!this.user.equals(other.user)) {
|
||||
return false;
|
||||
}
|
||||
if (this.proxyToolCalls == null) {
|
||||
return other.proxyToolCalls == null;
|
||||
}
|
||||
else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) {
|
||||
return false;
|
||||
}
|
||||
if (this.toolContext == null) {
|
||||
return other.toolContext == null;
|
||||
}
|
||||
else if (!this.toolContext.equals(other.toolContext)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private final MoonshotChatOptions options = new MoonshotChatOptions();
|
||||
|
||||
public Builder model(String model) {
|
||||
this.options.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder maxTokens(Integer maxTokens) {
|
||||
this.options.maxTokens = maxTokens;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder temperature(Double temperature) {
|
||||
this.options.temperature = temperature;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder topP(Double topP) {
|
||||
this.options.topP = topP;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder N(Integer n) {
|
||||
this.options.n = n;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder presencePenalty(Double presencePenalty) {
|
||||
this.options.presencePenalty = presencePenalty;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder frequencyPenalty(Double frequencyPenalty) {
|
||||
this.options.frequencyPenalty = frequencyPenalty;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder stop(List<String> stop) {
|
||||
this.options.stop = stop;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder user(String user) {
|
||||
this.options.user = user;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder tools(List<MoonshotApi.FunctionTool> tools) {
|
||||
this.options.tools = tools;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder toolChoice(String toolChoice) {
|
||||
this.options.toolChoice = toolChoice;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
|
||||
this.options.functionCallbacks = functionCallbacks;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder functions(Set<String> functionNames) {
|
||||
Assert.notNull(functionNames, "Function names must not be null");
|
||||
this.options.functions = functionNames;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder function(String functionName) {
|
||||
Assert.hasText(functionName, "Function name must not be empty");
|
||||
if (this.options.functions == null) {
|
||||
this.options.functions = new HashSet<>();
|
||||
}
|
||||
this.options.functions.add(functionName);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder proxyToolCalls(Boolean proxyToolCalls) {
|
||||
this.options.proxyToolCalls = proxyToolCalls;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder toolContext(Map<String, Object> toolContext) {
|
||||
if (this.options.toolContext == null) {
|
||||
this.options.toolContext = toolContext;
|
||||
}
|
||||
else {
|
||||
this.options.toolContext.putAll(toolContext);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public MoonshotChatOptions build() {
|
||||
return this.options;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot.aot;
|
||||
|
||||
import org.springframework.aot.hint.MemberCategory;
|
||||
import org.springframework.aot.hint.RuntimeHints;
|
||||
import org.springframework.aot.hint.RuntimeHintsRegistrar;
|
||||
import org.springframework.lang.NonNull;
|
||||
import org.springframework.lang.Nullable;
|
||||
|
||||
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
|
||||
|
||||
/**
|
||||
* The MoonshotRuntimeHints class is responsible for registering runtime hints for
|
||||
* Moonshot API classes.
|
||||
*
|
||||
* @author Geng Rong
|
||||
*/
|
||||
public class MoonshotRuntimeHints implements RuntimeHintsRegistrar {
|
||||
|
||||
@Override
|
||||
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
|
||||
var mcs = MemberCategory.values();
|
||||
for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.moonshot")) {
|
||||
hints.reflection().registerType(tr, mcs);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,645 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot.api;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude.Include;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import org.springframework.ai.model.ChatModelDescription;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.web.client.ResponseErrorHandler;
|
||||
import org.springframework.web.client.RestClient;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
|
||||
/**
|
||||
* Single-class, Java Client library for Moonshot platform. Provides implementation for
|
||||
* the <a href="https://platform.moonshot.cn/docs/api-reference">Chat Completion</a> APIs.
|
||||
* <p>
|
||||
* Implements <b>Synchronous</b> and <b>Streaming</b> chat completion.
|
||||
* </p>
|
||||
*
|
||||
* @author Geng Rong
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
public class MoonshotApi {
|
||||
|
||||
public static final String DEFAULT_CHAT_MODEL = ChatModel.MOONSHOT_V1_8K.getValue();
|
||||
|
||||
private static final Predicate<String> SSE_DONE_PREDICATE = "[DONE]"::equals;
|
||||
|
||||
private final RestClient restClient;
|
||||
|
||||
private final WebClient webClient;
|
||||
|
||||
private final MoonshotStreamFunctionCallingHelper chunkMerger = new MoonshotStreamFunctionCallingHelper();
|
||||
|
||||
/**
|
||||
* Create a new client api with DEFAULT_BASE_URL
|
||||
* @param moonshotApiKey Moonshot api Key.
|
||||
*/
|
||||
public MoonshotApi(String moonshotApiKey) {
|
||||
this(org.springframework.ai.moonshot.api.MoonshotConstants.DEFAULT_BASE_URL, moonshotApiKey);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new client api.
|
||||
* @param baseUrl api base URL.
|
||||
* @param moonshotApiKey Moonshot api Key.
|
||||
*/
|
||||
public MoonshotApi(String baseUrl, String moonshotApiKey) {
|
||||
this(baseUrl, moonshotApiKey, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new client api.
|
||||
* @param baseUrl api base URL.
|
||||
* @param moonshotApiKey Moonshot api Key.
|
||||
* @param restClientBuilder RestClient builder.
|
||||
* @param responseErrorHandler Response error handler.
|
||||
*/
|
||||
public MoonshotApi(String baseUrl, String moonshotApiKey, RestClient.Builder restClientBuilder,
|
||||
ResponseErrorHandler responseErrorHandler) {
|
||||
|
||||
Consumer<HttpHeaders> jsonContentHeaders = headers -> {
|
||||
headers.setBearerAuth(moonshotApiKey);
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
};
|
||||
|
||||
this.restClient = restClientBuilder.baseUrl(baseUrl)
|
||||
.defaultHeaders(jsonContentHeaders)
|
||||
.defaultStatusHandler(responseErrorHandler)
|
||||
.build();
|
||||
|
||||
this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build();
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a model response for the given chat conversation.
|
||||
* @param chatRequest The chat completion request.
|
||||
* @return Entity response with {@link ChatCompletion} as a body and HTTP status code
|
||||
* and headers.
|
||||
*/
|
||||
public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest chatRequest) {
|
||||
|
||||
Assert.notNull(chatRequest, "The request body can not be null.");
|
||||
Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false.");
|
||||
|
||||
return this.restClient.post()
|
||||
.uri("/v1/chat/completions")
|
||||
.body(chatRequest)
|
||||
.retrieve()
|
||||
.toEntity(ChatCompletion.class);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a streaming chat response for the given chat conversation.
|
||||
* @param chatRequest The chat completion request. Must have the stream property set
|
||||
* to true.
|
||||
* @return Returns a {@link Flux} stream from chat completion chunks.
|
||||
*/
|
||||
public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chatRequest) {
|
||||
Assert.notNull(chatRequest, "The request body can not be null.");
|
||||
Assert.isTrue(chatRequest.stream(), "Request must set the steam property to true.");
|
||||
AtomicBoolean isInsideTool = new AtomicBoolean(false);
|
||||
|
||||
return this.webClient.post()
|
||||
.uri("/v1/chat/completions")
|
||||
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
|
||||
.retrieve()
|
||||
.bodyToFlux(String.class)
|
||||
// cancels the flux stream after the "[DONE]" is received.
|
||||
.takeUntil(SSE_DONE_PREDICATE)
|
||||
// filters out the "[DONE]" message.
|
||||
.filter(SSE_DONE_PREDICATE.negate())
|
||||
.map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class))
|
||||
// Detect is the chunk is part of a streaming function call.
|
||||
.map(chunk -> {
|
||||
if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) {
|
||||
isInsideTool.set(true);
|
||||
}
|
||||
return chunk;
|
||||
})
|
||||
// Group all chunks belonging to the same function call.
|
||||
// Flux<ChatCompletionChunk> -> Flux<Flux<ChatCompletionChunk>>
|
||||
.windowUntil(chunk -> {
|
||||
if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) {
|
||||
isInsideTool.set(false);
|
||||
return true;
|
||||
}
|
||||
return !isInsideTool.get();
|
||||
})
|
||||
// Merging the window chunks into a single chunk.
|
||||
// Reduce the inner Flux<ChatCompletionChunk> window into a single
|
||||
// Mono<ChatCompletionChunk>,
|
||||
// Flux<Flux<ChatCompletionChunk>> -> Flux<Mono<ChatCompletionChunk>>
|
||||
.concatMapIterable(window -> {
|
||||
Mono<ChatCompletionChunk> monoChunk = window.reduce(
|
||||
new ChatCompletionChunk(null, null, null, null, null),
|
||||
(previous, current) -> this.chunkMerger.merge(previous, current));
|
||||
return List.of(monoChunk);
|
||||
})
|
||||
// Flux<Mono<ChatCompletionChunk>> -> Flux<ChatCompletionChunk>
|
||||
.flatMap(mono -> mono);
|
||||
}
|
||||
|
||||
/**
|
||||
* The reason the model stopped generating tokens.
|
||||
*/
|
||||
public enum ChatCompletionFinishReason {
|
||||
|
||||
/**
|
||||
* The model hit a natural stop point or a provided stop sequence.
|
||||
*/
|
||||
@JsonProperty("stop")
|
||||
STOP,
|
||||
/**
|
||||
* The maximum number of tokens specified in the request was reached.
|
||||
*/
|
||||
@JsonProperty("length")
|
||||
LENGTH,
|
||||
/**
|
||||
* The content was omitted due to a flag from our content filters.
|
||||
*/
|
||||
@JsonProperty("content_filter")
|
||||
CONTENT_FILTER,
|
||||
/**
|
||||
* The model called a tool.
|
||||
*/
|
||||
@JsonProperty("tool_calls")
|
||||
TOOL_CALLS,
|
||||
/**
|
||||
* Only for compatibility with Mistral AI API.
|
||||
*/
|
||||
@JsonProperty("tool_call")
|
||||
TOOL_CALL
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Moonshot Chat Completion Models:
|
||||
*
|
||||
* <ul>
|
||||
* <li><b>MOONSHOT_V1_8K</b> - moonshot-v1-8k</li>
|
||||
* <li><b>MOONSHOT_V1_32K</b> - moonshot-v1-32k</li>
|
||||
* <li><b>MOONSHOT_V1_128K</b> - moonshot-v1-128k</li>
|
||||
* </ul>
|
||||
*/
|
||||
public enum ChatModel implements ChatModelDescription {
|
||||
|
||||
// @formatter:off
|
||||
MOONSHOT_V1_8K("moonshot-v1-8k"),
|
||||
MOONSHOT_V1_32K("moonshot-v1-32k"),
|
||||
MOONSHOT_V1_128K("moonshot-v1-128k");
|
||||
// @formatter:on
|
||||
|
||||
private final String value;
|
||||
|
||||
ChatModel(String value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
public String getValue() {
|
||||
return this.value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return this.value;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Usage statistics.
|
||||
*
|
||||
* @param promptTokens Number of tokens in the prompt.
|
||||
* @param totalTokens Total number of tokens used in the request (prompt +
|
||||
* completion).
|
||||
* @param completionTokens Number of tokens in the generated completion. Only
|
||||
* applicable for completion requests.
|
||||
*/
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
public record Usage(
|
||||
// @formatter:off
|
||||
@JsonProperty("prompt_tokens") Integer promptTokens,
|
||||
@JsonProperty("total_tokens") Integer totalTokens,
|
||||
@JsonProperty("completion_tokens") Integer completionTokens) {
|
||||
// @formatter:on
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a model response for the given chat conversation.
|
||||
*
|
||||
* @param model ID of the model to use.
|
||||
* @param messages A list of messages comprising the conversation so far.
|
||||
* @param maxTokens The maximum number of tokens to generate in the chat completion.
|
||||
* The total length of input tokens and generated tokens is limited by the model's
|
||||
* context length.
|
||||
* @param temperature What sampling temperature to use, between 0 and 1. Higher values
|
||||
* like 0.8 will make the output more random, while lower values like 0.2 will make it
|
||||
* more focused and deterministic. We generally recommend altering this or top_p but
|
||||
* not both.
|
||||
* @param topP An alternative to sampling with temperature, called nucleus sampling,
|
||||
* where the model considers the results of the tokens with top_p probability mass. So
|
||||
* 0.1 means only the tokens comprising the top 10% probability mass are considered.
|
||||
* We generally recommend altering this or temperature but not both.
|
||||
* @param n How many chat completion choices to generate for each input message. Note
|
||||
* that you will be charged based on the number of generated tokens across all the
|
||||
* choices. Keep n as 1 to minimize costs.
|
||||
* @param presencePenalty Number between -2.0 and 2.0. Positive values penalize new
|
||||
* tokens based on whether they appear in the text so far, increasing the model's
|
||||
* likelihood to talk about new topics.
|
||||
* @param frequencyPenalty Number between -2.0 and 2.0. Positive values penalize new
|
||||
* tokens based on their existing frequency in the text so far, decreasing the model's
|
||||
* likelihood to repeat the same line verbatim.
|
||||
* @param stop Up to 5 sequences where the API will stop generating further tokens.
|
||||
* @param stream If set, partial message deltas will be sent.Tokens will be sent as
|
||||
* data-only server-sent events as they become available, with the stream terminated
|
||||
* by a data: [DONE] message.
|
||||
* @param tools A list of tools the model may call. Currently, only functions are
|
||||
* supported as a tool.
|
||||
* @param toolChoice Controls which (if any) function is called by the model.
|
||||
*/
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
public record ChatCompletionRequest(
|
||||
// @formatter:off
|
||||
@JsonProperty("messages") List<ChatCompletionMessage> messages,
|
||||
@JsonProperty("model") String model,
|
||||
@JsonProperty("max_tokens") Integer maxTokens,
|
||||
@JsonProperty("temperature") Double temperature,
|
||||
@JsonProperty("top_p") Double topP,
|
||||
@JsonProperty("n") Integer n,
|
||||
@JsonProperty("frequency_penalty") Double frequencyPenalty,
|
||||
@JsonProperty("presence_penalty") Double presencePenalty,
|
||||
@JsonProperty("stop") List<String> stop,
|
||||
@JsonProperty("stream") Boolean stream,
|
||||
@JsonProperty("tools") List<FunctionTool> tools,
|
||||
@JsonProperty("tool_choice") Object toolChoice) {
|
||||
// @formatter:on
|
||||
|
||||
/**
|
||||
* Shortcut constructor for a chat completion request with the given messages and
|
||||
* model.
|
||||
* @param messages The prompt(s) to generate completions for, encoded as a list of
|
||||
* dict with role and content. The first prompt role should be user or system.
|
||||
* @param model ID of the model to use.
|
||||
*/
|
||||
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model) {
|
||||
this(messages, model, null, 0.3, 1.0, null, null, null, null, false, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Shortcut constructor for a chat completion request with the given messages,
|
||||
* model and temperature.
|
||||
* @param messages The prompt(s) to generate completions for, encoded as a list of
|
||||
* dict with role and content. The first prompt role should be user or system.
|
||||
* @param model ID of the model to use.
|
||||
* @param temperature What sampling temperature to use, between 0.0 and 1.0.
|
||||
* @param stream Whether to stream back partial progress. If set, tokens will be
|
||||
* sent
|
||||
*/
|
||||
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature,
|
||||
boolean stream) {
|
||||
this(messages, model, null, temperature, 1.0, null, null, null, null, stream, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Shortcut constructor for a chat completion request with the given messages,
|
||||
* model and temperature.
|
||||
* @param messages The prompt(s) to generate completions for, encoded as a list of
|
||||
* dict with role and content. The first prompt role should be user or system.
|
||||
* @param model ID of the model to use.
|
||||
* @param temperature What sampling temperature to use, between 0.0 and 1.0.
|
||||
*/
|
||||
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
|
||||
this(messages, model, null, temperature, 1.0, null, null, null, null, false, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Shortcut constructor for a chat completion request with the given messages,
|
||||
* model, tools and tool choice. Streaming is set to false, temperature to 0.8 and
|
||||
* all other parameters are null.
|
||||
* @param messages A list of messages comprising the conversation so far.
|
||||
* @param model ID of the model to use.
|
||||
* @param tools A list of tools the model may call. Currently, only functions are
|
||||
* supported as a tool.
|
||||
* @param toolChoice Controls which (if any) function is called by the model.
|
||||
*/
|
||||
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, List<FunctionTool> tools,
|
||||
Object toolChoice) {
|
||||
this(messages, model, null, null, 1.0, null, null, null, null, false, tools, toolChoice);
|
||||
}
|
||||
|
||||
/**
|
||||
* Shortcut constructor for a chat completion request with the given messages and
|
||||
* stream.
|
||||
*/
|
||||
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
|
||||
this(messages, DEFAULT_CHAT_MODEL, null, 0.7, 1.0, null, null, null, null, stream, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper factory that creates a tool_choice of type 'none', 'auto' or selected
|
||||
* function by name.
|
||||
*/
|
||||
public static class ToolChoiceBuilder {
|
||||
|
||||
/**
|
||||
* Model can pick between generating a message or calling a function.
|
||||
*/
|
||||
public static final String AUTO = "auto";
|
||||
|
||||
/**
|
||||
* Model will not call a function and instead generates a message
|
||||
*/
|
||||
public static final String NONE = "none";
|
||||
|
||||
/**
|
||||
* Specifying a particular function forces the model to call that function.
|
||||
*/
|
||||
public static Object function(String functionName) {
|
||||
return Map.of("type", "function", "function", Map.of("name", functionName));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Message comprising the conversation.
|
||||
*
|
||||
* @param rawContent The raw contents of the message.
|
||||
* @param role The role of the message's author. Could be one of the {@link Role}
|
||||
* types.
|
||||
* @param name The name of the message's author.
|
||||
* @param toolCallId The ID of the tool call associated with the message.
|
||||
* @param toolCalls The list of tool calls associated with the message.
|
||||
*/
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
public record ChatCompletionMessage(
|
||||
// @formatter:off
|
||||
@JsonProperty("content") Object rawContent,
|
||||
@JsonProperty("role") Role role,
|
||||
@JsonProperty("name") String name,
|
||||
@JsonProperty("tool_call_id") String toolCallId,
|
||||
@JsonProperty("tool_calls") List<ToolCall> toolCalls
|
||||
// @formatter:on
|
||||
) {
|
||||
|
||||
/**
|
||||
* Create a chat completion message with the given content and role. All other
|
||||
* fields are null.
|
||||
* @param content The contents of the message.
|
||||
* @param role The role of the author of this message.
|
||||
*/
|
||||
public ChatCompletionMessage(Object content, Role role) {
|
||||
this(content, role, null, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get message content as String.
|
||||
*/
|
||||
public String content() {
|
||||
if (this.rawContent == null) {
|
||||
return null;
|
||||
}
|
||||
if (this.rawContent instanceof String text) {
|
||||
return text;
|
||||
}
|
||||
throw new IllegalStateException("The content is not a string!");
|
||||
}
|
||||
|
||||
/**
|
||||
* The role of the author of this message. NOTE: Moonshot expects the system
|
||||
* message to be before the user message or will fail with 400 error.
|
||||
*/
|
||||
public enum Role {
|
||||
|
||||
/**
|
||||
* System message.
|
||||
*/
|
||||
@JsonProperty("system")
|
||||
SYSTEM,
|
||||
/**
|
||||
* User message.
|
||||
*/
|
||||
@JsonProperty("user")
|
||||
USER,
|
||||
/**
|
||||
* Assistant message.
|
||||
*/
|
||||
@JsonProperty("assistant")
|
||||
ASSISTANT,
|
||||
/**
|
||||
* Tool message.
|
||||
*/
|
||||
@JsonProperty("tool")
|
||||
TOOL
|
||||
// @formatter:on
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* The relevant tool call.
|
||||
*
|
||||
* @param id The ID of the tool call. This ID must be referenced when you submit
|
||||
* the tool outputs in using the Submit tool outputs to run endpoint.
|
||||
* @param type The type of tool call the output is required for. For now, this is
|
||||
* always function.
|
||||
* @param function The function definition.
|
||||
*/
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
public record ToolCall(@JsonProperty("id") String id, @JsonProperty("type") String type,
|
||||
@JsonProperty("function") ChatCompletionFunction function) {
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* The function definition.
|
||||
*
|
||||
* @param name The name of the function.
|
||||
* @param arguments The arguments that the model expects you to pass to the
|
||||
* function.
|
||||
*/
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
public record ChatCompletionFunction(@JsonProperty("name") String name,
|
||||
@JsonProperty("arguments") String arguments) {
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a chat completion response returned by model, based on the provided
|
||||
* input.
|
||||
*
|
||||
* @param id A unique identifier for the chat completion.
|
||||
* @param object The object type, which is always chat.completion.
|
||||
* @param created The Unix timestamp (in seconds) of when the chat completion was
|
||||
* created.
|
||||
* @param model The model used for the chat completion.
|
||||
* @param choices A list of chat completion choices.
|
||||
* @param usage Usage statistics for the completion request.
|
||||
*/
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
public record ChatCompletion(
|
||||
// @formatter:off
|
||||
@JsonProperty("id") String id,
|
||||
@JsonProperty("object") String object,
|
||||
@JsonProperty("created") Long created,
|
||||
@JsonProperty("model") String model,
|
||||
@JsonProperty("choices") List<Choice> choices,
|
||||
@JsonProperty("usage") Usage usage) {
|
||||
// @formatter:on
|
||||
|
||||
/**
|
||||
* Chat completion choice.
|
||||
*
|
||||
* @param index The index of the choice in the list of choices.
|
||||
* @param message A chat completion message generated by the model.
|
||||
* @param finishReason The reason the model stopped generating tokens.
|
||||
*/
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
public record Choice(
|
||||
// @formatter:off
|
||||
@JsonProperty("index") Integer index,
|
||||
@JsonProperty("message") ChatCompletionMessage message,
|
||||
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
|
||||
@JsonProperty("usage") Usage usage) {
|
||||
// @formatter:on
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a streamed chunk of a chat completion response returned by model, based
|
||||
* on the provided input.
|
||||
*
|
||||
* @param id A unique identifier for the chat completion. Each chunk has the same ID.
|
||||
* @param object The object type, which is always 'chat.completion.chunk'.
|
||||
* @param created The Unix timestamp (in seconds) of when the chat completion was
|
||||
* created. Each chunk has the same timestamp.
|
||||
* @param model The model used for the chat completion.
|
||||
* @param choices A list of chat completion choices. Can be more than one if n is
|
||||
* greater than 1.
|
||||
*/
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
public record ChatCompletionChunk(
|
||||
// @formatter:off
|
||||
@JsonProperty("id") String id,
|
||||
@JsonProperty("object") String object,
|
||||
@JsonProperty("created") Long created,
|
||||
@JsonProperty("model") String model,
|
||||
@JsonProperty("choices") List<ChunkChoice> choices) {
|
||||
// @formatter:on
|
||||
|
||||
/**
|
||||
* Chat completion choice.
|
||||
*
|
||||
* @param index The index of the choice in the list of choices.
|
||||
* @param delta A chat completion delta generated by streamed model responses.
|
||||
* @param finishReason The reason the model stopped generating tokens.
|
||||
* @param usage Usage statistics for the completion request.
|
||||
*/
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
public record ChunkChoice(
|
||||
// @formatter:off
|
||||
@JsonProperty("index") Integer index,
|
||||
@JsonProperty("delta") ChatCompletionMessage delta,
|
||||
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
|
||||
@JsonProperty("usage") Usage usage
|
||||
// @formatter:on
|
||||
) {
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a tool the model may call. Currently, only functions are supported as a
|
||||
* tool.
|
||||
*
|
||||
* @param type The type of the tool. Currently, only 'function' is supported.
|
||||
* @param function The function definition.
|
||||
*/
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
public record FunctionTool(@JsonProperty("type") Type type, @JsonProperty("function") Function function) {
|
||||
|
||||
/**
|
||||
* Create a tool of type 'function' and the given function definition.
|
||||
* @param function function definition.
|
||||
*/
|
||||
public FunctionTool(Function function) {
|
||||
this(Type.FUNCTION, function);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a tool of type 'function' and the given function definition.
|
||||
*/
|
||||
public enum Type {
|
||||
|
||||
/**
|
||||
* Function tool type.
|
||||
*/
|
||||
@JsonProperty("function")
|
||||
FUNCTION
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Function definition.
|
||||
*
|
||||
* @param description A description of what the function does, used by the model
|
||||
* to choose when and how to call the function.
|
||||
* @param name The name of the function to be called. Must be a-z, A-Z, 0-9, or
|
||||
* contain underscores and dashes, with a maximum length of 64.
|
||||
* @param parameters The parameters the functions accepts, described as a JSON
|
||||
* Schema object. To describe a function that accepts no parameters, provide the
|
||||
* value {"type": "object", "properties": {}}.
|
||||
*/
|
||||
public record Function(@JsonProperty("description") String description, @JsonProperty("name") String name,
|
||||
@JsonProperty("parameters") Map<String, Object> parameters) {
|
||||
|
||||
/**
|
||||
* Create tool function definition.
|
||||
* @param description tool function description.
|
||||
* @param name tool function name.
|
||||
* @param jsonSchema tool function schema as json.
|
||||
*/
|
||||
public Function(String description, String name, String jsonSchema) {
|
||||
this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot.api;
|
||||
|
||||
import org.springframework.ai.observation.conventions.AiProvider;
|
||||
|
||||
/**
|
||||
* Constants for Moonshot API.
|
||||
*
|
||||
* @author Geng Rong
|
||||
*/
|
||||
public final class MoonshotConstants {
|
||||
|
||||
public static final String DEFAULT_BASE_URL = "https://api.moonshot.cn";
|
||||
|
||||
public static final String PROVIDER_NAME = AiProvider.MOONSHOT.value();
|
||||
|
||||
private MoonshotConstants() {
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,172 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot.api;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionChunk;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionChunk.ChunkChoice;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionFinishReason;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.ChatCompletionFunction;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.Role;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.ToolCall;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Helper class to support Streaming function calling. It can merge the streamed
|
||||
* ChatCompletionChunk in case of function calling message.
|
||||
*
|
||||
* @author Geng Rong
|
||||
*/
|
||||
public class MoonshotStreamFunctionCallingHelper {
|
||||
|
||||
public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChunk current) {
|
||||
|
||||
if (previous == null) {
|
||||
return current;
|
||||
}
|
||||
|
||||
String id = (current.id() != null ? current.id() : previous.id());
|
||||
Long created = (current.created() != null ? current.created() : previous.created());
|
||||
String model = (current.model() != null ? current.model() : previous.model());
|
||||
String object = (current.object() != null ? current.object() : previous.object());
|
||||
|
||||
ChunkChoice previousChoice0 = (CollectionUtils.isEmpty(previous.choices()) ? null : previous.choices().get(0));
|
||||
ChunkChoice currentChoice0 = (CollectionUtils.isEmpty(current.choices()) ? null : current.choices().get(0));
|
||||
|
||||
ChunkChoice choice = merge(previousChoice0, currentChoice0);
|
||||
List<ChunkChoice> chunkChoices = choice == null ? List.of() : List.of(choice);
|
||||
return new ChatCompletionChunk(id, object, created, model, chunkChoices);
|
||||
}
|
||||
|
||||
private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
|
||||
if (previous == null) {
|
||||
return current;
|
||||
}
|
||||
|
||||
ChatCompletionFinishReason finishReason = (current.finishReason() != null ? current.finishReason()
|
||||
: previous.finishReason());
|
||||
Integer index = (current.index() != null ? current.index() : previous.index());
|
||||
|
||||
MoonshotApi.Usage usage = current.usage() != null ? current.usage() : previous.usage();
|
||||
|
||||
ChatCompletionMessage message = merge(previous.delta(), current.delta());
|
||||
return new ChunkChoice(index, message, finishReason, usage);
|
||||
}
|
||||
|
||||
private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {
|
||||
String content = (current.content() != null ? current.content()
|
||||
: "" + ((previous.content() != null) ? previous.content() : ""));
|
||||
Role role = (current.role() != null ? current.role() : previous.role());
|
||||
role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null
|
||||
String name = (current.name() != null ? current.name() : previous.name());
|
||||
String toolCallId = (current.toolCallId() != null ? current.toolCallId() : previous.toolCallId());
|
||||
|
||||
List<ToolCall> toolCalls = new ArrayList<>();
|
||||
ToolCall lastPreviousTooCall = null;
|
||||
if (previous.toolCalls() != null) {
|
||||
lastPreviousTooCall = previous.toolCalls().get(previous.toolCalls().size() - 1);
|
||||
if (previous.toolCalls().size() > 1) {
|
||||
toolCalls.addAll(previous.toolCalls().subList(0, previous.toolCalls().size() - 1));
|
||||
}
|
||||
}
|
||||
if (current.toolCalls() != null) {
|
||||
if (current.toolCalls().size() > 1) {
|
||||
throw new IllegalStateException("Currently only one tool call is supported per message!");
|
||||
}
|
||||
var currentToolCall = current.toolCalls().iterator().next();
|
||||
if (currentToolCall.id() != null) {
|
||||
if (lastPreviousTooCall != null) {
|
||||
toolCalls.add(lastPreviousTooCall);
|
||||
}
|
||||
toolCalls.add(currentToolCall);
|
||||
}
|
||||
else {
|
||||
toolCalls.add(merge(lastPreviousTooCall, currentToolCall));
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (lastPreviousTooCall != null) {
|
||||
toolCalls.add(lastPreviousTooCall);
|
||||
}
|
||||
}
|
||||
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls);
|
||||
}
|
||||
|
||||
private ToolCall merge(ToolCall previous, ToolCall current) {
|
||||
if (previous == null) {
|
||||
return current;
|
||||
}
|
||||
String id = (current.id() != null ? current.id() : previous.id());
|
||||
String type = (current.type() != null ? current.type() : previous.type());
|
||||
ChatCompletionFunction function = merge(previous.function(), current.function());
|
||||
return new ToolCall(id, type, function);
|
||||
}
|
||||
|
||||
private ChatCompletionFunction merge(ChatCompletionFunction previous, ChatCompletionFunction current) {
|
||||
if (previous == null) {
|
||||
return current;
|
||||
}
|
||||
String name = (current.name() != null ? current.name() : previous.name());
|
||||
StringBuilder arguments = new StringBuilder();
|
||||
if (previous.arguments() != null) {
|
||||
arguments.append(previous.arguments());
|
||||
}
|
||||
if (current.arguments() != null) {
|
||||
arguments.append(current.arguments());
|
||||
}
|
||||
return new ChatCompletionFunction(name, arguments.toString());
|
||||
}
|
||||
|
||||
/**
|
||||
* @param chatCompletion the ChatCompletionChunk to check
|
||||
* @return true if the ChatCompletionChunk is a streaming tool function call.
|
||||
*/
|
||||
public boolean isStreamingToolFunctionCall(ChatCompletionChunk chatCompletion) {
|
||||
|
||||
if (chatCompletion == null || CollectionUtils.isEmpty(chatCompletion.choices())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
var choice = chatCompletion.choices().get(0);
|
||||
if (choice == null || choice.delta() == null) {
|
||||
return false;
|
||||
}
|
||||
return !CollectionUtils.isEmpty(choice.delta().toolCalls());
|
||||
}
|
||||
|
||||
/**
|
||||
* @param chatCompletion the ChatCompletionChunk to check
|
||||
* @return true if the ChatCompletionChunk is a streaming tool function call and it is
|
||||
* the last one.
|
||||
*/
|
||||
public boolean isStreamingToolFunctionCallFinish(ChatCompletionChunk chatCompletion) {
|
||||
|
||||
if (chatCompletion == null || CollectionUtils.isEmpty(chatCompletion.choices())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
var choice = chatCompletion.choices().get(0);
|
||||
if (choice == null || choice.delta() == null) {
|
||||
return false;
|
||||
}
|
||||
return choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,2 +0,0 @@
|
||||
org.springframework.aot.hint.RuntimeHintsRegistrar=\
|
||||
org.springframework.ai.moonshot.aot.MoonshotRuntimeHints
|
||||
@@ -1,60 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Geng Rong
|
||||
* @author Alexandros Pappas
|
||||
*/
|
||||
@SpringBootTest
|
||||
@EnabledIfEnvironmentVariable(named = "MOONSHOT_API_KEY", matches = ".+")
|
||||
public class MoonshotChatCompletionRequestTest {
|
||||
|
||||
MoonshotChatModel chatModel = new MoonshotChatModel(new MoonshotApi("test"));
|
||||
|
||||
@Test
|
||||
void chatCompletionDefaultRequestTest() {
|
||||
var request = this.chatModel.createRequest(new Prompt("test content"), false);
|
||||
|
||||
assertThat(request.messages()).hasSize(1);
|
||||
assertThat(request.topP()).isEqualTo(1);
|
||||
assertThat(request.temperature()).isEqualTo(0.7);
|
||||
assertThat(request.maxTokens()).isNull();
|
||||
assertThat(request.stream()).isFalse();
|
||||
}
|
||||
|
||||
@Test
|
||||
void chatCompletionRequestWithOptionsTest() {
|
||||
var options = MoonshotChatOptions.builder().temperature(0.5).topP(0.8).build();
|
||||
var request = this.chatModel.createRequest(new Prompt("test content", options), true);
|
||||
|
||||
assertThat(request.messages().size()).isEqualTo(1);
|
||||
assertThat(request.topP()).isEqualTo(0.8);
|
||||
assertThat(request.temperature()).isEqualTo(0.5);
|
||||
assertThat(request.stream()).isTrue();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,154 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionChunk;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionFinishReason;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.Role;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionRequest;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.ai.retry.TransientAiException;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.retry.RetryCallback;
|
||||
import org.springframework.retry.RetryContext;
|
||||
import org.springframework.retry.RetryListener;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.mockito.ArgumentMatchers.isA;
|
||||
import static org.mockito.BDDMockito.given;
|
||||
|
||||
/**
|
||||
* @author Geng Rong
|
||||
* @author Alexandros Pappas
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
public class MoonshotRetryTests {
|
||||
|
||||
private TestRetryListener retryListener;
|
||||
|
||||
private @Mock MoonshotApi moonshotApi;
|
||||
|
||||
private MoonshotChatModel chatModel;
|
||||
|
||||
@BeforeEach
|
||||
public void beforeEach() {
|
||||
RetryTemplate retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE;
|
||||
this.retryListener = new TestRetryListener();
|
||||
retryTemplate.registerListener(this.retryListener);
|
||||
|
||||
this.chatModel = new MoonshotChatModel(this.moonshotApi,
|
||||
MoonshotChatOptions.builder()
|
||||
.temperature(0.7)
|
||||
.topP(1.0)
|
||||
.model(MoonshotApi.ChatModel.MOONSHOT_V1_32K.getValue())
|
||||
.build(),
|
||||
null, retryTemplate);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void moonshotChatTransientError() {
|
||||
|
||||
var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
|
||||
ChatCompletionFinishReason.STOP, null);
|
||||
ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789L, "model",
|
||||
List.of(choice), new MoonshotApi.Usage(10, 10, 10));
|
||||
|
||||
given(this.moonshotApi.chatCompletionEntity(isA(ChatCompletionRequest.class)))
|
||||
.willThrow(new TransientAiException("Transient Error 1"))
|
||||
.willThrow(new TransientAiException("Transient Error 2"))
|
||||
.willReturn(ResponseEntity.of(Optional.of(expectedChatCompletion)));
|
||||
|
||||
var result = this.chatModel.call(new Prompt("text"));
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.getResult().getOutput().getText()).isSameAs("Response");
|
||||
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2);
|
||||
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void moonshotChatNonTransientError() {
|
||||
given(this.moonshotApi.chatCompletionEntity(isA(ChatCompletionRequest.class)))
|
||||
.willThrow(new RuntimeException("Non Transient Error"));
|
||||
assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text")));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void moonshotChatStreamTransientError() {
|
||||
|
||||
var choice = new ChatCompletionChunk.ChunkChoice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
|
||||
ChatCompletionFinishReason.STOP, null);
|
||||
ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789L,
|
||||
"model", List.of(choice));
|
||||
|
||||
given(this.moonshotApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
|
||||
.willThrow(new TransientAiException("Transient Error 1"))
|
||||
.willThrow(new TransientAiException("Transient Error 2"))
|
||||
.willReturn(Flux.just(expectedChatCompletion));
|
||||
|
||||
var result = this.chatModel.stream(new Prompt("text"));
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.collectList().block().get(0).getResult().getOutput().getText()).isSameAs("Response");
|
||||
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2);
|
||||
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void moonshotChatStreamNonTransientError() {
|
||||
given(this.moonshotApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
|
||||
.willThrow(new RuntimeException("Non Transient Error"));
|
||||
assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text")).collectList().block());
|
||||
}
|
||||
|
||||
private static class TestRetryListener implements RetryListener {
|
||||
|
||||
int onErrorRetryCount = 0;
|
||||
|
||||
int onSuccessRetryCount = 0;
|
||||
|
||||
@Override
|
||||
public <T, E extends Throwable> void onSuccess(RetryContext context, RetryCallback<T, E> callback, T result) {
|
||||
this.onSuccessRetryCount = context.getRetryCount();
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T, E extends Throwable> void onError(RetryContext context, RetryCallback<T, E> callback,
|
||||
Throwable throwable) {
|
||||
this.onErrorRetryCount = context.getRetryCount();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot;
|
||||
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* @author Geng Rong
|
||||
*/
|
||||
@SpringBootConfiguration
|
||||
public class MoonshotTestConfiguration {
|
||||
|
||||
@Bean
|
||||
public MoonshotApi moonshotApi() {
|
||||
var apiKey = System.getenv("MOONSHOT_API_KEY");
|
||||
if (!StringUtils.hasText(apiKey)) {
|
||||
throw new IllegalArgumentException(
|
||||
"Missing MOONSHOT_API_KEY environment variable. Please set it to your Moonshot API key.");
|
||||
}
|
||||
return new MoonshotApi(apiKey);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public MoonshotChatModel moonshotChatModel(MoonshotApi moonshotApi) {
|
||||
return new MoonshotChatModel(moonshotApi);
|
||||
}
|
||||
|
||||
public void tst() {
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot.aot;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.moonshot.MoonshotChatOptions;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi;
|
||||
import org.springframework.aot.hint.RuntimeHints;
|
||||
import org.springframework.aot.hint.TypeReference;
|
||||
|
||||
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
|
||||
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
|
||||
|
||||
/**
|
||||
* @author Geng Rong
|
||||
*/
|
||||
class MoonshotRuntimeHintsTests {
|
||||
|
||||
@Test
|
||||
void registerHints() {
|
||||
RuntimeHints runtimeHints = new RuntimeHints();
|
||||
MoonshotRuntimeHints moonshotRuntimeHints = new MoonshotRuntimeHints();
|
||||
moonshotRuntimeHints.registerHints(runtimeHints, null);
|
||||
|
||||
Set<TypeReference> jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage("org.springframework.ai.moonshot");
|
||||
|
||||
Set<TypeReference> registeredTypes = new HashSet<>();
|
||||
runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));
|
||||
|
||||
for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) {
|
||||
assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue();
|
||||
}
|
||||
|
||||
// Check a few more specific ones
|
||||
assertThat(registeredTypes.contains(TypeReference.of(MoonshotApi.ChatCompletion.class))).isTrue();
|
||||
assertThat(registeredTypes.contains(TypeReference.of(MoonshotApi.ChatCompletionRequest.class))).isTrue();
|
||||
assertThat(registeredTypes.contains(TypeReference.of(MoonshotApi.ChatCompletionChunk.class))).isTrue();
|
||||
assertThat(registeredTypes.contains(TypeReference.of(MoonshotApi.Usage.class))).isTrue();
|
||||
assertThat(registeredTypes.contains(TypeReference.of(MoonshotChatOptions.class))).isTrue();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot.api;
|
||||
|
||||
import java.util.function.Function;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonClassDescription;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude.Include;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
|
||||
|
||||
/**
|
||||
* @author Geng Rong
|
||||
*/
|
||||
public class MockWeatherService implements Function<MockWeatherService.Request, MockWeatherService.Response> {
|
||||
|
||||
@Override
|
||||
public Response apply(Request request) {
|
||||
|
||||
double temperature = 0;
|
||||
if (request.location().contains("Paris")) {
|
||||
temperature = 15;
|
||||
}
|
||||
else if (request.location().contains("Tokyo")) {
|
||||
temperature = 10;
|
||||
}
|
||||
else if (request.location().contains("San Francisco")) {
|
||||
temperature = 30;
|
||||
}
|
||||
|
||||
return new Response(temperature, 15, 20, 2, 53, 45, request.unit);
|
||||
}
|
||||
|
||||
/**
|
||||
* Temperature units.
|
||||
*/
|
||||
public enum Unit {
|
||||
|
||||
/**
|
||||
* Celsius.
|
||||
*/
|
||||
C("metric"),
|
||||
/**
|
||||
* Fahrenheit.
|
||||
*/
|
||||
F("imperial");
|
||||
|
||||
/**
|
||||
* Human readable unit name.
|
||||
*/
|
||||
public final String unitName;
|
||||
|
||||
Unit(String text) {
|
||||
this.unitName = text;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Weather Function request.
|
||||
*/
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
@JsonClassDescription("Weather API request")
|
||||
public record Request(@JsonProperty(required = true,
|
||||
value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location,
|
||||
@JsonProperty("lat") @JsonPropertyDescription("The city latitude") double lat,
|
||||
@JsonProperty("lon") @JsonPropertyDescription("The city longitude") double lon,
|
||||
@JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) {
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Weather Function response.
|
||||
*/
|
||||
public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity,
|
||||
Unit unit) {
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot.api;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionChunk;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.Role;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionRequest;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Geng Rong
|
||||
*/
|
||||
@EnabledIfEnvironmentVariable(named = "MOONSHOT_API_KEY", matches = ".+")
|
||||
public class MoonshotApiIT {
|
||||
|
||||
MoonshotApi moonshotApi = new MoonshotApi(System.getenv("MOONSHOT_API_KEY"));
|
||||
|
||||
@Test
|
||||
void chatCompletionEntity() {
|
||||
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
|
||||
ResponseEntity<ChatCompletion> response = this.moonshotApi.chatCompletionEntity(new ChatCompletionRequest(
|
||||
List.of(chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.8, false));
|
||||
|
||||
assertThat(response).isNotNull();
|
||||
assertThat(response.getBody()).isNotNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void chatCompletionEntityWithSystemMessage() {
|
||||
ChatCompletionMessage userMessage = new ChatCompletionMessage(
|
||||
"Tell me about 3 famous pirates from the Golden Age of Piracy and why they did?", Role.USER);
|
||||
ChatCompletionMessage systemMessage = new ChatCompletionMessage("""
|
||||
You are an AI assistant that helps people find information.
|
||||
Your name is Bob.
|
||||
You should reply to the user's request with your name and also in the style of a pirate.
|
||||
""", Role.SYSTEM);
|
||||
|
||||
ResponseEntity<ChatCompletion> response = this.moonshotApi.chatCompletionEntity(new ChatCompletionRequest(
|
||||
List.of(systemMessage, userMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.8, false));
|
||||
|
||||
assertThat(response).isNotNull();
|
||||
assertThat(response.getBody()).isNotNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void chatCompletionStream() {
|
||||
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
|
||||
Flux<ChatCompletionChunk> response = this.moonshotApi.chatCompletionStream(new ChatCompletionRequest(
|
||||
List.of(chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.8, true));
|
||||
|
||||
assertThat(response).isNotNull();
|
||||
assertThat(response.collectList().block()).isNotNull();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,151 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot.api;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.Role;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.ToolCall;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionRequest;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionRequest.ToolChoiceBuilder;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.FunctionTool;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi.FunctionTool.Type;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Geng Rong
|
||||
*/
|
||||
@EnabledIfEnvironmentVariable(named = "MOONSHOT_API_KEY", matches = ".+")
|
||||
public class MoonshotApiToolFunctionCallIT {
|
||||
|
||||
private static final FunctionTool FUNCTION_TOOL = new FunctionTool(Type.FUNCTION, new FunctionTool.Function(
|
||||
"Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state e.g. San Francisco, CA"
|
||||
},
|
||||
"lat": {
|
||||
"type": "number",
|
||||
"description": "The city latitude"
|
||||
},
|
||||
"lon": {
|
||||
"type": "number",
|
||||
"description": "The city longitude"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["C", "F"]
|
||||
}
|
||||
},
|
||||
"required": ["location", "lat", "lon", "unit"]
|
||||
}
|
||||
"""));
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(MoonshotApiToolFunctionCallIT.class);
|
||||
|
||||
private final MockWeatherService weatherService = new MockWeatherService();
|
||||
|
||||
private final MoonshotApi moonshotApi = new MoonshotApi(System.getenv("MOONSHOT_API_KEY"));
|
||||
|
||||
private static <T> T fromJson(String json, Class<T> targetClass) {
|
||||
try {
|
||||
return new ObjectMapper().readValue(json, targetClass);
|
||||
}
|
||||
catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("null")
|
||||
@Test
|
||||
public void toolFunctionCall() {
|
||||
toolFunctionCall("What's the weather like in San Francisco? Return the temperature in Celsius.",
|
||||
"San Francisco");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void toolFunctionCallChinese() {
|
||||
toolFunctionCall("旧金山、东京和巴黎的气温怎么样? 返回摄氏度的温度", "旧金山");
|
||||
}
|
||||
|
||||
private void toolFunctionCall(String userMessage, String cityName) {
|
||||
// Step 1: send the conversation and available functions to the model
|
||||
var message = new ChatCompletionMessage(userMessage, Role.USER);
|
||||
|
||||
List<ChatCompletionMessage> messages = new ArrayList<>(List.of(message));
|
||||
|
||||
ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(messages,
|
||||
MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), List.of(FUNCTION_TOOL), ToolChoiceBuilder.AUTO);
|
||||
|
||||
ResponseEntity<ChatCompletion> chatCompletion = this.moonshotApi.chatCompletionEntity(chatCompletionRequest);
|
||||
|
||||
assertThat(chatCompletion.getBody()).isNotNull();
|
||||
assertThat(chatCompletion.getBody().choices()).isNotEmpty();
|
||||
|
||||
ChatCompletionMessage responseMessage = chatCompletion.getBody().choices().get(0).message();
|
||||
|
||||
assertThat(responseMessage.role()).isEqualTo(Role.ASSISTANT);
|
||||
assertThat(responseMessage.toolCalls()).isNotNull();
|
||||
|
||||
messages.add(responseMessage);
|
||||
|
||||
// Send the info for each function call and function response to the model.
|
||||
for (ToolCall toolCall : responseMessage.toolCalls()) {
|
||||
var functionName = toolCall.function().name();
|
||||
if ("getCurrentWeather".equals(functionName)) {
|
||||
MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(),
|
||||
MockWeatherService.Request.class);
|
||||
|
||||
MockWeatherService.Response weatherResponse = this.weatherService.apply(weatherRequest);
|
||||
|
||||
// extend conversation with function response.
|
||||
messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL,
|
||||
functionName, toolCall.id(), null));
|
||||
}
|
||||
}
|
||||
|
||||
var functionResponseRequest = new ChatCompletionRequest(messages,
|
||||
MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.5);
|
||||
|
||||
ResponseEntity<ChatCompletion> chatCompletion2 = this.moonshotApi.chatCompletionEntity(functionResponseRequest);
|
||||
|
||||
logger.info("Final response: " + chatCompletion2.getBody());
|
||||
|
||||
assertThat(Objects.requireNonNull(chatCompletion2.getBody()).choices()).isNotEmpty();
|
||||
|
||||
assertThat(chatCompletion2.getBody().choices().get(0).message().role()).isEqualTo(Role.ASSISTANT);
|
||||
assertThat(chatCompletion2.getBody().choices().get(0).message().content()).contains(cityName)
|
||||
.containsAnyOf("30");
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot.chat;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author Geng Rong
|
||||
*/
|
||||
public class ActorsFilms {
|
||||
|
||||
private String actor;
|
||||
|
||||
private List<String> movies;
|
||||
|
||||
public ActorsFilms() {
|
||||
}
|
||||
|
||||
public String getActor() {
|
||||
return this.actor;
|
||||
}
|
||||
|
||||
public void setActor(String actor) {
|
||||
this.actor = actor;
|
||||
}
|
||||
|
||||
public List<String> getMovies() {
|
||||
return this.movies;
|
||||
}
|
||||
|
||||
public void setMovies(List<String> movies) {
|
||||
this.movies = movies;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "ActorsFilms{" + "actor='" + this.actor + '\'' + ", movies=" + this.movies + '}';
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,183 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot.chat;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.moonshot.MoonshotChatOptions;
|
||||
import org.springframework.ai.moonshot.MoonshotTestConfiguration;
|
||||
import org.springframework.ai.moonshot.api.MockWeatherService;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@SpringBootTest(classes = MoonshotTestConfiguration.class)
|
||||
@EnabledIfEnvironmentVariable(named = "MOONSHOT_API_KEY", matches = ".+")
|
||||
class MoonshotChatModelFunctionCallingIT {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(MoonshotChatModelFunctionCallingIT.class);
|
||||
|
||||
@Autowired
|
||||
ChatModel chatModel;
|
||||
|
||||
private static final MoonshotApi.FunctionTool FUNCTION_TOOL = new MoonshotApi.FunctionTool(
|
||||
MoonshotApi.FunctionTool.Type.FUNCTION, new MoonshotApi.FunctionTool.Function(
|
||||
"Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state e.g. San Francisco, CA"
|
||||
},
|
||||
"lat": {
|
||||
"type": "number",
|
||||
"description": "The city latitude"
|
||||
},
|
||||
"lon": {
|
||||
"type": "number",
|
||||
"description": "The city longitude"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["C", "F"]
|
||||
}
|
||||
},
|
||||
"required": ["location", "lat", "lon", "unit"]
|
||||
}
|
||||
"""));
|
||||
|
||||
@Test
|
||||
void functionCallTest() {
|
||||
|
||||
UserMessage userMessage = new UserMessage(
|
||||
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
|
||||
|
||||
List<Message> messages = new ArrayList<>(List.of(userMessage));
|
||||
|
||||
var promptOptions = MoonshotChatOptions.builder()
|
||||
.model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions));
|
||||
|
||||
logger.info("Response: {}", response);
|
||||
|
||||
assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15");
|
||||
}
|
||||
|
||||
@Test
|
||||
void streamFunctionCallTest() {
|
||||
|
||||
UserMessage userMessage = new UserMessage(
|
||||
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
|
||||
|
||||
List<Message> messages = new ArrayList<>(List.of(userMessage));
|
||||
|
||||
var promptOptions = MoonshotChatOptions.builder()
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
Flux<ChatResponse> response = this.chatModel.stream(new Prompt(messages, promptOptions));
|
||||
|
||||
String content = response.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
.map(ChatResponse::getResults)
|
||||
.flatMap(List::stream)
|
||||
.map(Generation::getOutput)
|
||||
.map(AssistantMessage::getText)
|
||||
.filter(Objects::nonNull)
|
||||
.collect(Collectors.joining());
|
||||
logger.info("Response: {}", content);
|
||||
|
||||
assertThat(content).contains("30", "10", "15");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void toolFunctionCallWithUsage() {
|
||||
var promptOptions = MoonshotChatOptions.builder()
|
||||
.model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
|
||||
.tools(Arrays.asList(FUNCTION_TOOL))
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location. Return temperature in 36°F or 36°C format.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.build();
|
||||
Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.",
|
||||
promptOptions);
|
||||
|
||||
ChatResponse chatResponse = this.chatModel.call(prompt);
|
||||
assertThat(chatResponse).isNotNull();
|
||||
assertThat(chatResponse.getResult().getOutput());
|
||||
assertThat(chatResponse.getResult().getOutput().getText()).contains("San Francisco");
|
||||
assertThat(chatResponse.getResult().getOutput().getText()).contains("30.0");
|
||||
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStreamFunctionCallUsage() {
|
||||
var promptOptions = MoonshotChatOptions.builder()
|
||||
.model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
|
||||
.tools(Arrays.asList(FUNCTION_TOOL))
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location. Return temperature in 36°F or 36°C format.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.build();
|
||||
Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.",
|
||||
promptOptions);
|
||||
|
||||
ChatResponse chatResponse = this.chatModel.stream(prompt).blockLast();
|
||||
assertThat(chatResponse).isNotNull();
|
||||
assertThat(chatResponse.getMetadata()).isNotNull();
|
||||
assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
|
||||
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,206 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot.chat;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.ListOutputConverter;
|
||||
import org.springframework.ai.converter.MapOutputConverter;
|
||||
import org.springframework.ai.moonshot.MoonshotTestConfiguration;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.core.convert.support.DefaultConversionService;
|
||||
import org.springframework.core.io.Resource;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Geng Rong
|
||||
*/
|
||||
@SpringBootTest(classes = MoonshotTestConfiguration.class)
|
||||
@EnabledIfEnvironmentVariable(named = "MOONSHOT_API_KEY", matches = ".+")
|
||||
public class MoonshotChatModelIT {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(MoonshotChatModelIT.class);
|
||||
|
||||
@Autowired
|
||||
protected ChatModel chatModel;
|
||||
|
||||
@Autowired
|
||||
protected StreamingChatModel streamingChatModel;
|
||||
|
||||
@Value("classpath:/prompts/system-message.st")
|
||||
private Resource systemResource;
|
||||
|
||||
@Test
|
||||
void roleTest() {
|
||||
UserMessage userMessage = new UserMessage(
|
||||
"Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.");
|
||||
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
|
||||
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
|
||||
ChatResponse response = this.chatModel.call(prompt);
|
||||
assertThat(response.getResults()).hasSize(1);
|
||||
assertThat(response.getResults().get(0).getOutput().getText()).contains("Blackbeard");
|
||||
}
|
||||
|
||||
@Test
|
||||
void listOutputConverter() {
|
||||
DefaultConversionService conversionService = new DefaultConversionService();
|
||||
ListOutputConverter outputConverter = new ListOutputConverter(conversionService);
|
||||
|
||||
String format = outputConverter.getFormat();
|
||||
String template = """
|
||||
List five {subject}
|
||||
{format}
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "ice cream flavors", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
List<String> list = outputConverter.convert(generation.getOutput().getText());
|
||||
assertThat(list).hasSize(5);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void mapOutputConverter() {
|
||||
MapOutputConverter outputConverter = new MapOutputConverter();
|
||||
|
||||
// TODO investigate why additional text was needed to generate the correct output.
|
||||
|
||||
String format = outputConverter.getFormat();
|
||||
String template = """
|
||||
Provide me a List of {subject}
|
||||
{format}
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", """
|
||||
numbers from 1 to 9 under they key name 'numbers'.
|
||||
For example here is a list of numbers from 1 to 3 the required format
|
||||
{
|
||||
"numbers": [1, 2, 3]
|
||||
}""", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
Map<String, Object> result = outputConverter.convert(generation.getOutput().getText());
|
||||
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void beanOutputConverter() {
|
||||
|
||||
BeanOutputConverter<ActorsFilms> outputConverter = new BeanOutputConverter<>(ActorsFilms.class);
|
||||
|
||||
String format = outputConverter.getFormat();
|
||||
String template = """
|
||||
Generate the filmography for a random actor.
|
||||
{format}
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getText());
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void beanOutputConverterRecords() {
|
||||
|
||||
BeanOutputConverter<ActorsFilmsRecord> outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class);
|
||||
|
||||
// TODO investigate why standard beanoutput converter text is not working and
|
||||
// additional text to specify
|
||||
// json format is needed. The response without the additional text returns "null"
|
||||
// for actor.
|
||||
|
||||
String format = outputConverter.getFormat();
|
||||
String template = """
|
||||
Generate the filmography of 5 movies for Tom Hanks.
|
||||
{format}
|
||||
|
||||
Your response should be without ```json``` and $schema
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getText());
|
||||
logger.info("" + actorsFilms);
|
||||
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
|
||||
assertThat(actorsFilms.movies()).hasSize(5);
|
||||
}
|
||||
|
||||
@Test
|
||||
void beanStreamOutputConverterRecords() {
|
||||
|
||||
BeanOutputConverter<ActorsFilmsRecord> outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class);
|
||||
|
||||
String format = outputConverter.getFormat();
|
||||
String template = """
|
||||
Generate the filmography of 5 movies for Tom Hanks.
|
||||
{format}
|
||||
|
||||
your response should be without ```json``` and $shcema.
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
|
||||
String generationTextFromStream = this.streamingChatModel.stream(prompt)
|
||||
.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
.map(ChatResponse::getResults)
|
||||
.flatMap(List::stream)
|
||||
.map(Generation::getOutput)
|
||||
.map(AssistantMessage::getText)
|
||||
.collect(Collectors.joining());
|
||||
|
||||
ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream);
|
||||
logger.info("" + actorsFilms);
|
||||
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
|
||||
assertThat(actorsFilms.movies()).hasSize(5);
|
||||
}
|
||||
|
||||
record ActorsFilmsRecord(String actor, List<String> movies) {
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,183 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.moonshot.chat;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import io.micrometer.observation.tck.TestObservationRegistry;
|
||||
import io.micrometer.observation.tck.TestObservationRegistryAssert;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
|
||||
import org.springframework.ai.moonshot.MoonshotChatModel;
|
||||
import org.springframework.ai.moonshot.MoonshotChatOptions;
|
||||
import org.springframework.ai.moonshot.api.MoonshotApi;
|
||||
import org.springframework.ai.observation.conventions.AiOperationType;
|
||||
import org.springframework.ai.observation.conventions.AiProvider;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames;
|
||||
import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames;
|
||||
|
||||
/**
|
||||
* Integration tests for observation instrumentation in {@link MoonshotChatModel}.
|
||||
*
|
||||
* @author Geng Rong
|
||||
* @author Alexandros Pappas
|
||||
*/
|
||||
@SpringBootTest(classes = MoonshotChatModelObservationIT.Config.class)
|
||||
@EnabledIfEnvironmentVariable(named = "MOONSHOT_API_KEY", matches = ".+")
|
||||
public class MoonshotChatModelObservationIT {
|
||||
|
||||
@Autowired
|
||||
TestObservationRegistry observationRegistry;
|
||||
|
||||
@Autowired
|
||||
MoonshotChatModel chatModel;
|
||||
|
||||
@BeforeEach
|
||||
void beforeEach() {
|
||||
this.observationRegistry.clear();
|
||||
}
|
||||
|
||||
@Test
|
||||
void observationForChatOperation() {
|
||||
|
||||
var options = MoonshotChatOptions.builder()
|
||||
.model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
|
||||
.frequencyPenalty(0.0)
|
||||
.maxTokens(2048)
|
||||
.presencePenalty(0.0)
|
||||
.stop(List.of("this-is-the-end"))
|
||||
.temperature(0.7)
|
||||
.topP(1.0)
|
||||
.build();
|
||||
|
||||
Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
|
||||
|
||||
ChatResponse chatResponse = this.chatModel.call(prompt);
|
||||
assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty();
|
||||
|
||||
ChatResponseMetadata responseMetadata = chatResponse.getMetadata();
|
||||
assertThat(responseMetadata).isNotNull();
|
||||
|
||||
validate(responseMetadata);
|
||||
}
|
||||
|
||||
@Test
|
||||
void observationForStreamingChatOperation() {
|
||||
var options = MoonshotChatOptions.builder()
|
||||
.model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
|
||||
.frequencyPenalty(0.0)
|
||||
.maxTokens(2048)
|
||||
.presencePenalty(0.0)
|
||||
.stop(List.of("this-is-the-end"))
|
||||
.temperature(0.7)
|
||||
.topP(1.0)
|
||||
.build();
|
||||
|
||||
Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
|
||||
|
||||
Flux<ChatResponse> chatResponseFlux = this.chatModel.stream(prompt);
|
||||
|
||||
List<ChatResponse> responses = chatResponseFlux.collectList().block();
|
||||
assertThat(responses).isNotEmpty();
|
||||
assertThat(responses).hasSizeGreaterThan(10);
|
||||
|
||||
String aggregatedResponse = responses.subList(0, responses.size() - 1)
|
||||
.stream()
|
||||
.map(r -> r.getResult().getOutput().getText())
|
||||
.collect(Collectors.joining());
|
||||
assertThat(aggregatedResponse).isNotEmpty();
|
||||
|
||||
ChatResponse lastChatResponse = responses.get(responses.size() - 1);
|
||||
|
||||
ChatResponseMetadata responseMetadata = lastChatResponse.getMetadata();
|
||||
assertThat(responseMetadata).isNotNull();
|
||||
|
||||
validate(responseMetadata);
|
||||
}
|
||||
|
||||
private void validate(ChatResponseMetadata responseMetadata) {
|
||||
TestObservationRegistryAssert.assertThat(this.observationRegistry)
|
||||
.doesNotHaveAnyRemainingCurrentObservation()
|
||||
.hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME)
|
||||
.that()
|
||||
.hasContextualNameEqualTo("chat " + MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
|
||||
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
|
||||
AiOperationType.CHAT.value())
|
||||
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.MOONSHOT.value())
|
||||
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(),
|
||||
MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
|
||||
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), responseMetadata.getModel())
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), "0.0")
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048")
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), "0.0")
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(),
|
||||
"[\"this-is-the-end\"]")
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7")
|
||||
.doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_TOP_K.asString())
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TOP_P.asString(), "1.0")
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.RESPONSE_ID.asString(), responseMetadata.getId())
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.RESPONSE_FINISH_REASONS.asString(), "[\"STOP\"]")
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
|
||||
String.valueOf(responseMetadata.getUsage().getPromptTokens()))
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(),
|
||||
String.valueOf(responseMetadata.getUsage().getCompletionTokens()))
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
|
||||
String.valueOf(responseMetadata.getUsage().getTotalTokens()))
|
||||
.hasBeenStarted()
|
||||
.hasBeenStopped();
|
||||
}
|
||||
|
||||
@SpringBootConfiguration
|
||||
static class Config {
|
||||
|
||||
@Bean
|
||||
public TestObservationRegistry observationRegistry() {
|
||||
return TestObservationRegistry.create();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public MoonshotApi moonshotApi() {
|
||||
return new MoonshotApi(System.getenv("MOONSHOT_API_KEY"));
|
||||
}
|
||||
|
||||
@Bean
|
||||
public MoonshotChatModel moonshotChatModel(MoonshotApi moonshotApi,
|
||||
TestObservationRegistry observationRegistry) {
|
||||
return new MoonshotChatModel(moonshotApi, MoonshotChatOptions.builder().build(),
|
||||
new DefaultFunctionCallbackResolver(), List.of(), RetryTemplate.defaultInstance(),
|
||||
observationRegistry);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
You are an AI assistant that helps people find information.
|
||||
Your name is {name}.
|
||||
You should reply to the user's request with your name and also in the style of a {voice}.
|
||||
19
pom.xml
19
pom.xml
@@ -67,13 +67,12 @@
|
||||
<module>auto-configurations/models/spring-ai-autoconfigure-model-openai</module>
|
||||
<module>auto-configurations/models/spring-ai-autoconfigure-model-minimax</module>
|
||||
<module>auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai</module>
|
||||
<module>auto-configurations/models/spring-ai-autoconfigure-model-moonshot</module>
|
||||
<module>auto-configurations/models/spring-ai-autoconfigure-model-oci-genai</module>
|
||||
<module>auto-configurations/models/spring-ai-autoconfigure-model-oci-genai</module>
|
||||
<module>auto-configurations/models/spring-ai-autoconfigure-model-ollama</module>
|
||||
|
||||
<module>auto-configurations/models/spring-ai-autoconfigure-model-postgresml-embedding</module>
|
||||
<module>auto-configurations/models/spring-ai-autoconfigure-model-qianfan</module>
|
||||
<module>auto-configurations/models/spring-ai-autoconfigure-model-stability-ai</module>
|
||||
<module>auto-configurations/models/spring-ai-autoconfigure-model-stability-ai</module>
|
||||
<module>auto-configurations/models/spring-ai-autoconfigure-model-transformers</module>
|
||||
<module>auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai</module>
|
||||
<module>auto-configurations/models/spring-ai-autoconfigure-model-watsonx-ai</module>
|
||||
@@ -170,14 +169,13 @@
|
||||
<module>models/spring-ai-openai</module>
|
||||
<module>models/spring-ai-postgresml</module>
|
||||
<module>models/spring-ai-qianfan</module>
|
||||
<module>models/spring-ai-stability-ai</module>
|
||||
<module>models/spring-ai-stability-ai</module>
|
||||
<module>models/spring-ai-transformers</module>
|
||||
<module>models/spring-ai-vertex-ai-embedding</module>
|
||||
<module>models/spring-ai-vertex-ai-gemini</module>
|
||||
<module>models/spring-ai-watsonx-ai</module>
|
||||
<module>models/spring-ai-zhipuai</module>
|
||||
<module>models/spring-ai-moonshot</module>
|
||||
|
||||
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-model-anthropic</module>
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-model-azure-openai</module>
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-model-bedrock</module>
|
||||
@@ -193,14 +191,13 @@
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-model-openai</module>
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-model-postgresml-embedding</module>
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-model-qianfan</module>
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-model-stability-ai</module>
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-model-stability-ai</module>
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-model-transformers</module>
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-model-vertex-ai-embedding</module>
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-model-vertex-ai-gemini</module>
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-model-watsonx-ai</module>
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-model-zhipuai</module>
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-model-moonshot</module>
|
||||
|
||||
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-mcp-client</module>
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-mcp-server</module>
|
||||
<module>spring-ai-spring-boot-starters/spring-ai-starter-mcp-client-webflux</module>
|
||||
@@ -712,11 +709,9 @@
|
||||
<exclude>org.springframework.ai.huggingface/**/*IT.java</exclude>
|
||||
<exclude>org.springframework.ai.minimax/**/*IT.java</exclude>
|
||||
<exclude>org.springframework.ai.mistralai/**/*IT.java</exclude>
|
||||
<exclude>org.springframework.ai.moonshot/**/*IT.java</exclude>
|
||||
<exclude>org.springframework.ai.oci/**/*IT.java</exclude>
|
||||
<exclude>org.springframework.ai.ollama/**/*IT.java</exclude> <!-- <exclude>org.springframework.ai.openai/**/*IT.java</exclude> -->
|
||||
<exclude>org.springframework.ai.postgresml/**/*IT.java</exclude>
|
||||
<exclude>org.springframework.ai.qianfan/**/*IT.java</exclude>
|
||||
<exclude>org.springframework.ai.stabilityai/**/*IT.java</exclude>
|
||||
<exclude>org.springframework.ai.transformers/**/*IT.java</exclude>
|
||||
<exclude>org.springframework.ai.vertexai.embedding/**/*IT.java</exclude>
|
||||
@@ -765,12 +760,10 @@
|
||||
|
||||
<exclude>org.springframework.ai.autoconfigure.minimax/**/**IT.java</exclude>
|
||||
<exclude>org.springframework.ai.autoconfigure.mistralai/**/**IT.java</exclude>
|
||||
<exclude>org.springframework.ai.autoconfigure.moonshot/**/**IT.java</exclude>
|
||||
<exclude>org.springframework.ai.autoconfigure.oci/**/**IT.java</exclude>
|
||||
<exclude>org.springframework.ai.autoconfigure.ollama/**/**IT.java</exclude>
|
||||
<!-- <exclude>org.springframework.ai.autoconfigure.openai/**/**IT.java</exclude> -->
|
||||
<exclude>org.springframework.ai.autoconfigure.postgresml/**/**IT.java</exclude>
|
||||
<exclude>org.springframework.ai.autoconfigure.qianfan/**/**IT.java</exclude>
|
||||
|
||||
<exclude>org.springframework.ai.autoconfigure.retry/**/**IT.java</exclude>
|
||||
|
||||
|
||||
@@ -210,11 +210,6 @@
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-moonshot</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
@@ -240,11 +235,11 @@
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-qianfan</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-qianfan</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
@@ -534,11 +529,6 @@
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-autoconfigure-model-moonshot</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
@@ -893,11 +883,6 @@
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-starter-model-moonshot</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
|
||||
@@ -1,202 +1,5 @@
|
||||
= Function Calling
|
||||
|
||||
You can register custom Java functions with the `MoonshotChatModel` and have the Moonshot model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions.
|
||||
This allows you to connect the LLM capabilities with external tools and APIs.
|
||||
The Moonshot models are trained to detect when a function should be called and to respond with JSON that adheres to the function signature.
|
||||
This functionality has been moved to the Spring AI Community repository.
|
||||
|
||||
The Moonshot API does not call the function directly; instead, the model generates JSON that you can use to call the function in your code and return the result back to the model to complete the conversation.
|
||||
|
||||
Spring AI provides flexible and user-friendly ways to register and call custom functions.
|
||||
In general, the custom functions need to provide a function `name`, `description`, and the function call `signature` (as JSON schema) to let the model know what arguments the function expects. The `description` helps the model to understand when to call the function.
|
||||
|
||||
As a developer, you need to implement a function that takes the function call arguments sent from the AI model, and responds with the result back to the model. Your function can in turn invoke other 3rd party services to provide the results.
|
||||
|
||||
Spring AI makes this as easy as defining a `@Bean` definition that returns a `java.util.Function` and supplying the bean name as an option when invoking the `ChatModel`.
|
||||
|
||||
Under the hood, Spring wraps your POJO (the function) with the appropriate adapter code that enables interaction with the AI Model, saving you from writing tedious boilerplate code.
|
||||
The basis of the underlying infrastructure is the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-model/src/main/java/org/springframework/ai/model/function/FunctionCallback.java[FunctionCallback.java] interface and the companion Builder utility class to simplify the implementation and registration of Java callback functions.
|
||||
|
||||
// Additionally, the Auto-Configuration provides a way to auto-register any Function<I, O> beans definition as function calling candidates in the `ChatModel`.
|
||||
|
||||
|
||||
== How it works
|
||||
|
||||
Suppose we want the AI model to respond with information that it does not have, for example the current temperature at a given location.
|
||||
|
||||
We can provide the AI model with metadata about our own functions that it can use to retrieve that information as it processes your prompt.
|
||||
|
||||
For example, if during the processing of a prompt, the AI Model determines that it needs additional information about the temperature in a given location, it will start a server side generated request/response interaction. The AI Model invokes a client side function.
|
||||
The AI Model provides method invocation details as JSON and it is the responsibility of the client to execute that function and return the response.
|
||||
|
||||
The model-client interaction is illustrated in the <<spring-ai-function-calling-flow>> diagram.
|
||||
|
||||
Spring AI greatly simplifies code you need to write to support function invocation.
|
||||
It brokers the function invocation conversation for you.
|
||||
You can simply provide your function definition as a `@Bean` and then provide the bean name of the function in your prompt options.
|
||||
You can also reference multiple function bean names in your prompt.
|
||||
|
||||
== Quick Start
|
||||
|
||||
Let's create a chatbot that answer questions by calling our own function.
|
||||
To support the response of the chatbot, we will register our own function that takes a location and returns the current weather in that location.
|
||||
|
||||
When the response to the prompt to the model needs to answer a question such as `"What’s the weather like in Boston?"` the AI model will invoke the client providing the location value as an argument to be passed to the function. This RPC-like data is passed as JSON.
|
||||
|
||||
Our function calls some SaaS based weather service API and returns the weather response back to the model to complete the conversation. In this example we will use a simple implementation named `MockWeatherService` that hard codes the temperature for various locations.
|
||||
|
||||
The following `MockWeatherService.java` represents the weather service API:
|
||||
|
||||
[source,java]
|
||||
----
|
||||
public class MockWeatherService implements Function<Request, Response> {
|
||||
|
||||
public enum Unit { C, F }
|
||||
public record Request(String location, Unit unit) {}
|
||||
public record Response(double temp, Unit unit) {}
|
||||
|
||||
public Response apply(Request request) {
|
||||
return new Response(30.0, Unit.C);
|
||||
}
|
||||
}
|
||||
----
|
||||
|
||||
=== Registering Functions as Beans
|
||||
|
||||
With the link:../minimax-chat.html#_auto_configuration[MoonshotChatModel Auto-Configuration] you have multiple ways to register custom functions as beans in the Spring context.
|
||||
|
||||
We start with describing the most POJO friendly options.
|
||||
|
||||
|
||||
==== Plain Java Functions
|
||||
|
||||
In this approach you define `@Beans` in your application context as you would any other Spring managed object.
|
||||
|
||||
Internally, Spring AI `ChatModel` will create an instance of a `FunctionCallback` instance that adds the logic for it being invoked via the AI model.
|
||||
The name of the `@Bean` is passed as a `ChatOption`.
|
||||
|
||||
|
||||
[source,java]
|
||||
----
|
||||
@Configuration
|
||||
static class Config {
|
||||
|
||||
@Bean
|
||||
@Description("Get the weather in location") // function description
|
||||
public Function<MockWeatherService.Request, MockWeatherService.Response> weatherFunction1() {
|
||||
return new MockWeatherService();
|
||||
}
|
||||
...
|
||||
}
|
||||
----
|
||||
|
||||
The `@Description` annotation is optional and provides a function description (2) that helps the model to understand when to call the function. It is an important property to set to help the AI model determine what client side function to invoke.
|
||||
|
||||
Another option to provide the description of the function is to the `@JacksonDescription` annotation on the `MockWeatherService.Request` to provide the function description:
|
||||
|
||||
[source,java]
|
||||
----
|
||||
|
||||
@Configuration
|
||||
static class Config {
|
||||
|
||||
@Bean
|
||||
public Function<Request, Response> currentWeather3() { // (1) bean name as function name.
|
||||
return new MockWeatherService();
|
||||
}
|
||||
...
|
||||
}
|
||||
|
||||
@JsonClassDescription("Get the weather in location") // (2) function description
|
||||
public record Request(String location, Unit unit) {}
|
||||
----
|
||||
|
||||
It is a best practice to annotate the request object with information such that the generates JSON schema of that function is as descriptive as possible to help the AI model pick the correct function to invoke.
|
||||
|
||||
The link:https://github.com/spring-projects/spring-ai/blob/main/auto-configurations/models/spring-ai-autoconfigure-model-moonshot/src/test/java/org/springframework/ai/model/moonshot/autoconfigure/tool/FunctionCallbackWithPlainFunctionBeanIT.java[FunctionCallbackWithPlainFunctionBeanIT.java] demonstrates this approach.
|
||||
|
||||
|
||||
==== FunctionCallback Wrapper
|
||||
|
||||
Another way register a function is to create `FunctionCallback` instance like this:
|
||||
|
||||
[source,java]
|
||||
----
|
||||
@Configuration
|
||||
static class Config {
|
||||
|
||||
@Bean
|
||||
public FunctionCallback weatherFunctionInfo() {
|
||||
|
||||
return FunctionCallback.builder()
|
||||
.function("CurrentWeather", new MockWeatherService()) // (1) function name and instance
|
||||
.description("Get the weather in location") // (2) function description
|
||||
.inputType(MockWeatherService.Request.class) // (3) function signature
|
||||
.build();
|
||||
}
|
||||
...
|
||||
}
|
||||
----
|
||||
|
||||
It wraps the 3rd party, `MockWeatherService` function and registers it as a `CurrentWeather` function with the `MoonshotChatModel`.
|
||||
It also provides a description (2) and the function signature (3) to let the model know what arguments the function expects.
|
||||
|
||||
NOTE: By default, the response converter does a JSON serialization of the Response object.
|
||||
|
||||
NOTE: The `FunctionCallback` internally resolves the function call signature based on the `MockWeatherService.Request` class.
|
||||
|
||||
=== Specifying functions in Chat Options
|
||||
|
||||
To let the model know and call your `CurrentWeather` function you need to enable it in your prompt requests:
|
||||
|
||||
[source,java]
|
||||
----
|
||||
MoonshotChatModel chatModel = ...
|
||||
|
||||
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
|
||||
|
||||
ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage),
|
||||
MoonshotChatOptions.builder().function("CurrentWeather").build())); // (1) Enable the function
|
||||
|
||||
logger.info("Response: {}", response);
|
||||
----
|
||||
|
||||
// NOTE: You can can have multiple functions registered in your `ChatModel` but only those enabled in the prompt request will be considered for the function calling.
|
||||
|
||||
Above user question will trigger 3 calls to `CurrentWeather` function (one for each city) and the final response will be something like this:
|
||||
|
||||
----
|
||||
Here is the current weather for the requested cities:
|
||||
- San Francisco, CA: 30.0°C
|
||||
- Tokyo, Japan: 10.0°C
|
||||
- Paris, France: 15.0°C
|
||||
----
|
||||
|
||||
The link:https://github.com/spring-projects/spring-ai/blob/main/auto-configurations/models/spring-ai-autoconfigure-model-moonshot/src/test/java/org/springframework/ai/model/moonshot/autoconfigure/tool/MoonshotFunctionCallbackIT.java[MoonshotFunctionCallbackIT.java] test demo this approach.
|
||||
|
||||
|
||||
=== Register/Call Functions with Prompt Options
|
||||
|
||||
In addition to the auto-configuration you can register callback functions, dynamically, with your Prompt requests:
|
||||
|
||||
[source,java]
|
||||
----
|
||||
MoonshotChatModel chatModel = ...
|
||||
|
||||
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
|
||||
|
||||
var promptOptions = MoonshotChatOptions.builder()
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("CurrentWeather", new MockWeatherService()) // (1) function name
|
||||
.description("Get the weather in location") // (2) function description
|
||||
.inputType(MockWeatherService.Request.class) // (3) function signature
|
||||
.build())) // function code
|
||||
.build();
|
||||
|
||||
ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), this.promptOptions));
|
||||
----
|
||||
|
||||
NOTE: The in-prompt registered functions are enabled by default for the duration of this request.
|
||||
|
||||
This approach allows to dynamically chose different functions to be called based on the user input.
|
||||
|
||||
The https://github.com/spring-projects/spring-ai/blob/main/auto-configurations/models/spring-ai-autoconfigure-model-moonshot/src/test/java/org/springframework/ai/model/moonshot/autoconfigure/tool/FunctionCallbackInPromptIT.java[FunctionCallbackInPromptIT.java] integration test provides a complete example of how to register a function with the `MoonshotChatModel` and use it in a prompt request.
|
||||
Please visit https://github.com/spring-ai-community/moonshot for the latest version.
|
||||
@@ -1,268 +1,5 @@
|
||||
= Moonshot AI Chat
|
||||
|
||||
Spring AI supports the various AI language models from Moonshot AI. You can interact with Moonshot AI language models and create a multilingual conversational assistant based on Moonshot models.
|
||||
This functionality has been moved to the Spring AI Community repository.
|
||||
|
||||
== Prerequisites
|
||||
|
||||
You will need to create an API with Moonshot to access Moonshot AI language models.
|
||||
Create an account at https://platform.moonshot.cn/console[Moonshot AI registration page] and generate the token on the https://platform.moonshot.cn/console/api-keys/[API Keys page].
|
||||
The Spring AI project defines a configuration property named `spring.ai.moonshot.api-key` that you should set to the value of the `API Key` obtained from https://platform.moonshot.cn/console/api-keys/[API Keys page].
|
||||
Exporting an environment variable is one way to set that configuration property:
|
||||
|
||||
[source,shell]
|
||||
----
|
||||
export SPRING_AI_MOONSHOT_API_KEY=<INSERT KEY HERE>
|
||||
----
|
||||
|
||||
=== Add Repositories and BOM
|
||||
|
||||
Spring AI artifacts are published in Maven Central and Spring Snapshot repositories.
|
||||
Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system.
|
||||
|
||||
To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system.
|
||||
|
||||
|
||||
|
||||
== Auto-configuration
|
||||
|
||||
[NOTE]
|
||||
====
|
||||
There has been a significant change in the Spring AI auto-configuration, starter modules' artifact names.
|
||||
Please refer to the https://docs.spring.io/spring-ai/reference/upgrade-notes.html[upgrade notes] for more information.
|
||||
====
|
||||
|
||||
Spring AI provides Spring Boot auto-configuration for the Moonshot Chat Model.
|
||||
To enable it add the following dependency to your project's Maven `pom.xml` file:
|
||||
|
||||
[source, xml]
|
||||
----
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-starter-model-moonshot</artifactId>
|
||||
</dependency>
|
||||
----
|
||||
|
||||
or to your Gradle `build.gradle` build file.
|
||||
|
||||
[source,groovy]
|
||||
----
|
||||
dependencies {
|
||||
implementation 'org.springframework.ai:spring-ai-starter-model-moonshot'
|
||||
}
|
||||
----
|
||||
|
||||
TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
|
||||
|
||||
=== Chat Properties
|
||||
|
||||
==== Retry Properties
|
||||
|
||||
The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the Moonshot AI Chat model.
|
||||
|
||||
[cols="3,5,1", stripes=even]
|
||||
|====
|
||||
| Property | Description | Default
|
||||
|
||||
| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10
|
||||
| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec.
|
||||
| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5
|
||||
| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min.
|
||||
| spring.ai.retry.on-client-errors | If false, throw a NonTransientAiException, and do not attempt retry for `4xx` client error codes | false
|
||||
| spring.ai.retry.exclude-on-http-codes | List of HTTP status codes that should not trigger a retry (e.g. to throw NonTransientAiException). | empty
|
||||
| spring.ai.retry.on-http-codes | List of HTTP status codes that should trigger a retry (e.g. to throw TransientAiException). | empty
|
||||
|====
|
||||
|
||||
==== Connection Properties
|
||||
|
||||
The prefix `spring.ai.moonshot` is used as the property prefix that lets you connect to Moonshot.
|
||||
|
||||
[cols="3,5,1", stripes=even]
|
||||
|====
|
||||
| Property | Description | Default
|
||||
|
||||
| spring.ai.moonshot.base-url | The URL to connect to | https://api.moonshot.cn
|
||||
| spring.ai.moonshot.api-key | The API Key | -
|
||||
|====
|
||||
|
||||
==== Configuration Properties
|
||||
|
||||
[NOTE]
|
||||
====
|
||||
Enabling and disabling of the chat auto-configurations are now configured via top level properties with the prefix `spring.ai.model.chat`.
|
||||
|
||||
To enable, spring.ai.model.chat=moonshot (It is enabled by default)
|
||||
|
||||
To disable, spring.ai.model.chat=none (or any value which doesn't match moonshot)
|
||||
|
||||
This change is done to allow configuration of multiple models.
|
||||
====
|
||||
|
||||
The prefix `spring.ai.moonshot.chat` is the property prefix that lets you configure the chat model implementation for Moonshot.
|
||||
|
||||
[cols="3,5,1", stripes=even]
|
||||
|====
|
||||
| Property | Description | Default
|
||||
|
||||
| spring.ai.moonshot.chat.enabled (Removed and no longer valid) | Enable Moonshot chat model. | true
|
||||
| spring.ai.model.chat | Enable Moonshot chat model. | moonshot
|
||||
| spring.ai.moonshot.chat.base-url | Optional overrides the spring.ai.moonshot.base-url to provide chat specific url | -
|
||||
| spring.ai.moonshot.chat.api-key | Optional overrides the spring.ai.moonshot.api-key to provide chat specific api-key | -
|
||||
| spring.ai.moonshot.chat.options.model | This is the Moonshot Chat model to use | `moonshot-v1-8k` (the `moonshot-v1-8k`, `moonshot-v1-32k`, and `moonshot-v1-128k` point to the latest model versions)
|
||||
| spring.ai.moonshot.chat.options.maxTokens | The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length. | -
|
||||
| spring.ai.moonshot.chat.options.temperature | The sampling temperature to use that controls the apparent creativity of generated completions. Higher values will make output more random while lower values will make results more focused and deterministic. It is not recommended to modify temperature and top_p for the same completions request as the interaction of these two settings is difficult to predict. | 0.7
|
||||
| spring.ai.moonshot.chat.options.topP | An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. | 1.0
|
||||
| spring.ai.moonshot.chat.options.n | How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Default value is 1 and cannot be greater than 5. Specifically, when the temperature is very small and close to 0, we can only return 1 result. If n is already set and>1 at this time, service will return an illegal input parameter (invalid_request_error) | 1
|
||||
| spring.ai.moonshot.chat.options.presencePenalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. | 0.0f
|
||||
| spring.ai.moonshot.chat.options.frequencyPenalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. | 0.0f
|
||||
| spring.ai.moonshot.chat.options.stop | Up to 5 sequences where the API will stop generating further tokens. Each string must not exceed 32 bytes | -
|
||||
|====
|
||||
|
||||
NOTE: You can override the common `spring.ai.moonshot.base-url` and `spring.ai.moonshot.api-key` for the `ChatModel` implementations.
|
||||
The `spring.ai.moonshot.chat.base-url` and `spring.ai.moonshot.chat.api-key` properties if set take precedence over the common properties.
|
||||
This is useful if you want to use different Moonshot accounts for different models and different model endpoints.
|
||||
|
||||
TIP: All properties prefixed with `spring.ai.moonshot.chat.options` can be overridden at runtime by adding a request specific <<chat-options>> to the `Prompt` call.
|
||||
|
||||
== Runtime Options [[chat-options]]
|
||||
|
||||
The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java[MoonshotChatOptions.java] provides model configurations, such as the model to use, the temperature, the frequency penalty, etc.
|
||||
|
||||
On start-up, the default options can be configured with the `MoonshotChatModel(api, options)` constructor or the `spring.ai.moonshot.chat.options.*` properties.
|
||||
|
||||
At run-time you can override the default options by adding new, request specific, options to the `Prompt` call.
|
||||
For example to override the default model and temperature for a specific request:
|
||||
|
||||
[source,java]
|
||||
----
|
||||
ChatResponse response = chatModel.call(
|
||||
new Prompt(
|
||||
"Generate the names of 5 famous pirates.",
|
||||
MoonshotChatOptions.builder()
|
||||
.model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
|
||||
.temperature(0.5)
|
||||
.build()
|
||||
));
|
||||
----
|
||||
|
||||
TIP: In addition to the model specific link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java[MoonshotChatOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-model/src/main/java/org/springframework/ai/chat/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-chat-client/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()].
|
||||
|
||||
== Sample Controller (Auto-configuration)
|
||||
|
||||
https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-starter-model-moonshot` to your pom (or gradle) dependencies.
|
||||
|
||||
Add a `application.properties` file, under the `src/main/resources` directory, to enable and configure the Moonshot Chat model:
|
||||
|
||||
[source,application.properties]
|
||||
----
|
||||
spring.ai.moonshot.api-key=YOUR_API_KEY
|
||||
spring.ai.moonshot.chat.options.model=moonshot-v1-8k
|
||||
spring.ai.moonshot.chat.options.temperature=0.7
|
||||
----
|
||||
|
||||
TIP: replace the `api-key` with your Moonshot credentials.
|
||||
|
||||
This will create a `MoonshotChatModel` implementation that you can inject into your class.
|
||||
Here is an example of a simple `@Controller` class that uses the chat model for text generations.
|
||||
|
||||
[source,java]
|
||||
----
|
||||
@RestController
|
||||
public class ChatController {
|
||||
|
||||
private final MoonshotChatModel chatModel;
|
||||
|
||||
@Autowired
|
||||
public ChatController(MoonshotChatModel chatModel) {
|
||||
this.chatModel = chatModel;
|
||||
}
|
||||
|
||||
@GetMapping("/ai/generate")
|
||||
public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
|
||||
return Map.of("generation", this.chatModel.call(message));
|
||||
}
|
||||
|
||||
@GetMapping("/ai/generateStream")
|
||||
public Flux<ChatResponse> generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
|
||||
var prompt = new Prompt(new UserMessage(message));
|
||||
return this.chatModel.stream(prompt);
|
||||
}
|
||||
}
|
||||
----
|
||||
|
||||
== Manual Configuration
|
||||
|
||||
The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java[MoonshotChatModel] implements the `ChatModel` and `StreamingChatModel` and uses the <<low-level-api>> to connect to the Moonshot service.
|
||||
|
||||
Add the `spring-ai-moonshot` dependency to your project's Maven `pom.xml` file:
|
||||
|
||||
[source, xml]
|
||||
----
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-moonshot</artifactId>
|
||||
</dependency>
|
||||
----
|
||||
|
||||
or to your Gradle `build.gradle` build file.
|
||||
|
||||
[source,groovy]
|
||||
----
|
||||
dependencies {
|
||||
implementation 'org.springframework.ai:spring-ai-moonshot'
|
||||
}
|
||||
----
|
||||
|
||||
TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
|
||||
|
||||
Next, create a `MoonshotChatModel` and use it for text generations:
|
||||
|
||||
[source,java]
|
||||
----
|
||||
var moonshotApi = new MoonshotApi(System.getenv("MOONSHOT_API_KEY"));
|
||||
|
||||
var chatModel = new MoonshotChatModel(this.moonshotApi, MoonshotChatOptions.builder()
|
||||
.model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
|
||||
.temperature(0.4)
|
||||
.maxTokens(200)
|
||||
.build());
|
||||
|
||||
ChatResponse response = this.chatModel.call(
|
||||
new Prompt("Generate the names of 5 famous pirates."));
|
||||
|
||||
// Or with streaming responses
|
||||
Flux<ChatResponse> streamResponse = this.chatModel.stream(
|
||||
new Prompt("Generate the names of 5 famous pirates."));
|
||||
----
|
||||
|
||||
The `MoonshotChatOptions` provides the configuration information for the chat requests.
|
||||
The `MoonshotChatOptions.Builder` is fluent options builder.
|
||||
|
||||
=== Low-level Moonshot Api Client [[low-level-api]]
|
||||
|
||||
The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java[MoonshotApi] provides is lightweight Java client for link:https://platform.moonshot.cn/docs/api-reference[Moonshot AI API].
|
||||
|
||||
Here is a simple snippet how to use the api programmatically:
|
||||
|
||||
[source,java]
|
||||
----
|
||||
MoonshotApi moonshotApi =
|
||||
new MoonshotApi(System.getenv("MOONSHOT_API_KEY"));
|
||||
|
||||
ChatCompletionMessage chatCompletionMessage =
|
||||
new ChatCompletionMessage("Hello world", Role.USER);
|
||||
|
||||
// Sync request
|
||||
ResponseEntity<ChatCompletion> response = this.moonshotApi.chatCompletionEntity(
|
||||
new ChatCompletionRequest(List.of(this.chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.7, false));
|
||||
|
||||
// Streaming request
|
||||
Flux<ChatCompletionChunk> streamResponse = this.moonshotApi.chatCompletionStream(
|
||||
new ChatCompletionRequest(List.of(this.chatCompletionMessage), MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue(), 0.7, true));
|
||||
----
|
||||
|
||||
Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java[MoonshotApi.java]'s JavaDoc for further information.
|
||||
|
||||
==== MoonshotApi Samples
|
||||
* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiIT.java[MoonshotApiIT.java] test provides some general examples how to use the lightweight library.
|
||||
|
||||
* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiToolFunctionCallIT.java[MoonshotApiToolFunctionCallIT.java] test shows how to use the low-level API to call tool functions.
|
||||
Please visit https://github.com/spring-ai-community/moonshot for the latest version.
|
||||
Reference in New Issue
Block a user