fix: #2146 getting default AWS region using DefaultAwsRegionProviderChain

Signed-off-by: Andrei Shakirin <andrei.shakirin@gmail.com>
This commit is contained in:
Andrei Shakirin
2025-05-16 21:25:58 +02:00
committed by Christian Tzolov
parent cd3fc2f816
commit 0bbccf8ca0
4 changed files with 168 additions and 3 deletions

View File

@@ -40,8 +40,10 @@ import reactor.core.scheduler.Schedulers;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
@@ -788,6 +790,12 @@ public class BedrockProxyChatModel implements ChatModel {
private BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient;
private Builder() {
try {
region = DefaultAwsRegionProviderChain.builder().build().getRegion();
}
catch (SdkClientException e) {
logger.warn("Failed to load region from DefaultAwsRegionProviderChain, using US_EAST_1", e);
}
}
public Builder toolCallingManager(ToolCallingManager toolCallingManager) {

View File

@@ -0,0 +1,47 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock.converse;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Answers;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
class BedrockProxyChatModelTest {
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private DefaultAwsRegionProviderChain.Builder awsRegionProviderBuilder;
@Test
void shouldIgnoreExceptionAndUseDefault() {
try (MockedStatic<DefaultAwsRegionProviderChain> mocked = mockStatic(DefaultAwsRegionProviderChain.class)) {
when(awsRegionProviderBuilder.build().getRegion())
.thenThrow(SdkClientException.builder().message("failed load").build());
mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(awsRegionProviderBuilder);
BedrockProxyChatModel.builder().build();
}
}
}

View File

@@ -30,13 +30,16 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.ObjectUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;
import reactor.core.publisher.Sinks.EmitFailureHandler;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
@@ -148,14 +151,12 @@ public abstract class AbstractBedrockApi<I, O, SO> {
Assert.hasText(modelId, "Model id must not be empty");
Assert.notNull(credentialsProvider, "Credentials provider must not be null");
Assert.notNull(region, "Region must not be empty");
Assert.notNull(objectMapper, "Object mapper must not be null");
Assert.notNull(timeout, "Timeout must not be null");
this.modelId = modelId;
this.objectMapper = objectMapper;
this.region = region;
this.region = getRegion(region);
this.client = BedrockRuntimeClient.builder()
.region(this.region)
@@ -339,5 +340,17 @@ public abstract class AbstractBedrockApi<I, O, SO> {
@JsonProperty("outputTokenCount") Long outputTokenCount,
@JsonProperty("invocationLatency") Long invocationLatency) {
}
private Region getRegion(Region region) {
if (ObjectUtils.isEmpty(region)) {
try {
return DefaultAwsRegionProviderChain.builder().build().getRegion();
} catch (SdkClientException e) {
throw new IllegalArgumentException("Region is empty and cannot be loaded from DefaultAwsRegionProviderChain: " + e.getMessage(), e);
}
} else {
return region;
}
}
}
// @formatter:on

View File

@@ -0,0 +1,97 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock.api;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Answers;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
import java.time.Duration;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class AbstractBedrockApiTest {
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private DefaultAwsRegionProviderChain.Builder awsRegionProviderBuilder;
@Mock
private AwsCredentialsProvider awsCredentialsProvider = mock(AwsCredentialsProvider.class);
@Mock
private ObjectMapper objectMapper = mock(ObjectMapper.class);
@Test
void shouldLoadRegionFromAwsDefaults() {
try (MockedStatic<DefaultAwsRegionProviderChain> mocked = mockStatic(DefaultAwsRegionProviderChain.class)) {
when(awsRegionProviderBuilder.build().getRegion()).thenReturn(Region.AF_SOUTH_1);
mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(awsRegionProviderBuilder);
AbstractBedrockApi<Object, Object, Object> testBedrockApi = new TestBedrockApi("modelId",
awsCredentialsProvider, null, objectMapper, Duration.ofMinutes(5));
assertThat(testBedrockApi.getRegion()).isEqualTo(Region.AF_SOUTH_1);
}
}
@Test
void shouldThrowIllegalArgumentIfAwsDefaultsFailed() {
try (MockedStatic<DefaultAwsRegionProviderChain> mocked = mockStatic(DefaultAwsRegionProviderChain.class)) {
when(awsRegionProviderBuilder.build().getRegion())
.thenThrow(SdkClientException.builder().message("failed load").build());
mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(awsRegionProviderBuilder);
assertThatThrownBy(() -> new TestBedrockApi("modelId", awsCredentialsProvider, null, objectMapper,
Duration.ofMinutes(5)))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("failed load");
}
}
private static class TestBedrockApi extends AbstractBedrockApi<Object, Object, Object> {
protected TestBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
ObjectMapper objectMapper, Duration timeout) {
super(modelId, credentialsProvider, region, objectMapper, timeout);
}
@Override
protected Object embedding(Object request) {
return null;
}
@Override
protected Object chatCompletion(Object request) {
return null;
}
@Override
protected Object internalInvocation(Object request, Class<Object> clazz) {
return null;
}
}
}