fix: #2146 getting default AWS region using DefaultAwsRegionProviderChain
Signed-off-by: Andrei Shakirin <andrei.shakirin@gmail.com>
This commit is contained in:
committed by
Christian Tzolov
parent
cd3fc2f816
commit
0bbccf8ca0
@@ -40,8 +40,10 @@ import reactor.core.scheduler.Schedulers;
|
||||
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
|
||||
import software.amazon.awssdk.core.SdkBytes;
|
||||
import software.amazon.awssdk.core.document.Document;
|
||||
import software.amazon.awssdk.core.exception.SdkClientException;
|
||||
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
|
||||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
|
||||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
|
||||
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
|
||||
@@ -788,6 +790,12 @@ public class BedrockProxyChatModel implements ChatModel {
|
||||
private BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient;
|
||||
|
||||
private Builder() {
|
||||
try {
|
||||
region = DefaultAwsRegionProviderChain.builder().build().getRegion();
|
||||
}
|
||||
catch (SdkClientException e) {
|
||||
logger.warn("Failed to load region from DefaultAwsRegionProviderChain, using US_EAST_1", e);
|
||||
}
|
||||
}
|
||||
|
||||
public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.bedrock.converse;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Answers;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.MockedStatic;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import software.amazon.awssdk.core.exception.SdkClientException;
|
||||
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
|
||||
|
||||
import static org.mockito.Mockito.mockStatic;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class BedrockProxyChatModelTest {
|
||||
|
||||
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
|
||||
private DefaultAwsRegionProviderChain.Builder awsRegionProviderBuilder;
|
||||
|
||||
@Test
|
||||
void shouldIgnoreExceptionAndUseDefault() {
|
||||
try (MockedStatic<DefaultAwsRegionProviderChain> mocked = mockStatic(DefaultAwsRegionProviderChain.class)) {
|
||||
when(awsRegionProviderBuilder.build().getRegion())
|
||||
.thenThrow(SdkClientException.builder().message("failed load").build());
|
||||
mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(awsRegionProviderBuilder);
|
||||
BedrockProxyChatModel.builder().build();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -30,13 +30,16 @@ import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.util.ObjectUtils;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Sinks;
|
||||
import reactor.core.publisher.Sinks.EmitFailureHandler;
|
||||
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
|
||||
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
|
||||
import software.amazon.awssdk.core.SdkBytes;
|
||||
import software.amazon.awssdk.core.exception.SdkClientException;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
|
||||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
|
||||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
|
||||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
|
||||
@@ -148,14 +151,12 @@ public abstract class AbstractBedrockApi<I, O, SO> {
|
||||
|
||||
Assert.hasText(modelId, "Model id must not be empty");
|
||||
Assert.notNull(credentialsProvider, "Credentials provider must not be null");
|
||||
Assert.notNull(region, "Region must not be empty");
|
||||
Assert.notNull(objectMapper, "Object mapper must not be null");
|
||||
Assert.notNull(timeout, "Timeout must not be null");
|
||||
|
||||
this.modelId = modelId;
|
||||
this.objectMapper = objectMapper;
|
||||
this.region = region;
|
||||
|
||||
this.region = getRegion(region);
|
||||
|
||||
this.client = BedrockRuntimeClient.builder()
|
||||
.region(this.region)
|
||||
@@ -339,5 +340,17 @@ public abstract class AbstractBedrockApi<I, O, SO> {
|
||||
@JsonProperty("outputTokenCount") Long outputTokenCount,
|
||||
@JsonProperty("invocationLatency") Long invocationLatency) {
|
||||
}
|
||||
|
||||
private Region getRegion(Region region) {
|
||||
if (ObjectUtils.isEmpty(region)) {
|
||||
try {
|
||||
return DefaultAwsRegionProviderChain.builder().build().getRegion();
|
||||
} catch (SdkClientException e) {
|
||||
throw new IllegalArgumentException("Region is empty and cannot be loaded from DefaultAwsRegionProviderChain: " + e.getMessage(), e);
|
||||
}
|
||||
} else {
|
||||
return region;
|
||||
}
|
||||
}
|
||||
}
|
||||
// @formatter:on
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.bedrock.api;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Answers;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.MockedStatic;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
|
||||
import software.amazon.awssdk.core.exception.SdkClientException;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class AbstractBedrockApiTest {
|
||||
|
||||
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
|
||||
private DefaultAwsRegionProviderChain.Builder awsRegionProviderBuilder;
|
||||
|
||||
@Mock
|
||||
private AwsCredentialsProvider awsCredentialsProvider = mock(AwsCredentialsProvider.class);
|
||||
|
||||
@Mock
|
||||
private ObjectMapper objectMapper = mock(ObjectMapper.class);
|
||||
|
||||
@Test
|
||||
void shouldLoadRegionFromAwsDefaults() {
|
||||
try (MockedStatic<DefaultAwsRegionProviderChain> mocked = mockStatic(DefaultAwsRegionProviderChain.class)) {
|
||||
when(awsRegionProviderBuilder.build().getRegion()).thenReturn(Region.AF_SOUTH_1);
|
||||
mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(awsRegionProviderBuilder);
|
||||
AbstractBedrockApi<Object, Object, Object> testBedrockApi = new TestBedrockApi("modelId",
|
||||
awsCredentialsProvider, null, objectMapper, Duration.ofMinutes(5));
|
||||
assertThat(testBedrockApi.getRegion()).isEqualTo(Region.AF_SOUTH_1);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldThrowIllegalArgumentIfAwsDefaultsFailed() {
|
||||
try (MockedStatic<DefaultAwsRegionProviderChain> mocked = mockStatic(DefaultAwsRegionProviderChain.class)) {
|
||||
when(awsRegionProviderBuilder.build().getRegion())
|
||||
.thenThrow(SdkClientException.builder().message("failed load").build());
|
||||
mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(awsRegionProviderBuilder);
|
||||
assertThatThrownBy(() -> new TestBedrockApi("modelId", awsCredentialsProvider, null, objectMapper,
|
||||
Duration.ofMinutes(5)))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("failed load");
|
||||
}
|
||||
}
|
||||
|
||||
private static class TestBedrockApi extends AbstractBedrockApi<Object, Object, Object> {
|
||||
|
||||
protected TestBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
|
||||
ObjectMapper objectMapper, Duration timeout) {
|
||||
super(modelId, credentialsProvider, region, objectMapper, timeout);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Object embedding(Object request) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Object chatCompletion(Object request) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Object internalInvocation(Object request, Class<Object> clazz) {
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user