GH-973 Ensure that AWS isBase64Encoded is set dynamically
Resolves #973
This commit is contained in:
@@ -46,6 +46,14 @@ public final class AWSLambdaUtils {
|
||||
|
||||
static final String AWS_EVENT = "aws-event";
|
||||
|
||||
static final String IS_BASE64_ENCODED = "isBase64Encoded";
|
||||
|
||||
static final String STATUS_CODE = "statusCode";
|
||||
|
||||
static final String BODY = "body";
|
||||
|
||||
static final String HEADERS = "headers";
|
||||
|
||||
/**
|
||||
* The name of the headers that stores AWS Context object.
|
||||
*/
|
||||
@@ -130,18 +138,19 @@ public final class AWSLambdaUtils {
|
||||
byte[] responseBytes = responseMessage == null ? "\"OK\"".getBytes() : extractPayload((Message<Object>) responseMessage, objectMapper);
|
||||
if (requestMessage.getHeaders().containsKey(AWS_API_GATEWAY) && ((boolean) requestMessage.getHeaders().get(AWS_API_GATEWAY))) {
|
||||
Map<String, Object> response = new HashMap<String, Object>();
|
||||
response.put("isBase64Encoded", false);
|
||||
response.put(IS_BASE64_ENCODED, responseMessage != null && responseMessage.getHeaders().containsKey(IS_BASE64_ENCODED)
|
||||
? responseMessage.getHeaders().get(IS_BASE64_ENCODED) : false);
|
||||
|
||||
AtomicReference<MessageHeaders> headers = new AtomicReference<>();
|
||||
int statusCode = HttpStatus.OK.value();
|
||||
if (responseMessage != null) {
|
||||
headers.set(responseMessage.getHeaders());
|
||||
statusCode = headers.get().containsKey("statusCode")
|
||||
? (int) headers.get().get("statusCode")
|
||||
statusCode = headers.get().containsKey(STATUS_CODE)
|
||||
? (int) headers.get().get(STATUS_CODE)
|
||||
: HttpStatus.OK.value();
|
||||
}
|
||||
|
||||
response.put("statusCode", statusCode);
|
||||
response.put(STATUS_CODE, statusCode);
|
||||
if (isRequestKinesis(requestMessage)) {
|
||||
HttpStatus httpStatus = HttpStatus.valueOf(statusCode);
|
||||
response.put("statusDescription", httpStatus.toString());
|
||||
@@ -149,12 +158,11 @@ public final class AWSLambdaUtils {
|
||||
|
||||
String body = responseMessage == null
|
||||
? "\"OK\"" : new String(extractPayload((Message<Object>) responseMessage, objectMapper), StandardCharsets.UTF_8);
|
||||
response.put("body", body);
|
||||
|
||||
response.put(BODY, body);
|
||||
if (responseMessage != null) {
|
||||
Map<String, String> responseHeaders = new HashMap<>();
|
||||
headers.get().keySet().forEach(key -> responseHeaders.put(key, headers.get().get(key).toString()));
|
||||
response.put("headers", responseHeaders);
|
||||
response.put(HEADERS, responseHeaders);
|
||||
}
|
||||
|
||||
try {
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
package org.springframework.cloud.function.adapter.aws;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Map;
|
||||
|
||||
import com.amazonaws.services.lambda.runtime.serialization.PojoSerializer;
|
||||
@@ -100,6 +101,9 @@ class AWSTypesMessageConverter extends JsonMessageConverter {
|
||||
@Override
|
||||
protected Object convertToInternal(Object payload, @Nullable MessageHeaders headers,
|
||||
@Nullable Object conversionHint) {
|
||||
if (payload instanceof String && headers.containsKey(AWSLambdaUtils.IS_BASE64_ENCODED) && (boolean) headers.get(AWSLambdaUtils.IS_BASE64_ENCODED)) {
|
||||
return ((String) payload).getBytes(StandardCharsets.UTF_8);
|
||||
}
|
||||
return jsonMapper.toJson(payload);
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Base64;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.function.Consumer;
|
||||
@@ -47,6 +48,7 @@ import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
|
||||
import org.springframework.cloud.function.json.JacksonMapper;
|
||||
import org.springframework.cloud.function.json.JsonMapper;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
@@ -948,6 +950,25 @@ public class FunctionInvokerTests {
|
||||
assertThat(result.get("body")).isEqualTo("\"Hello from Lambda\"");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testResponseBase64Encoded() throws Exception {
|
||||
System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName());
|
||||
System.setProperty("spring.cloud.function.definition", "echoStringMessage");
|
||||
FunctionInvoker invoker = new FunctionInvoker();
|
||||
|
||||
InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
|
||||
ByteArrayOutputStream output = new ByteArrayOutputStream();
|
||||
invoker.handleRequest(targetStream, output, null);
|
||||
|
||||
JsonMapper mapper = new JacksonMapper(new ObjectMapper());
|
||||
|
||||
String result = new String(output.toByteArray(), StandardCharsets.UTF_8);
|
||||
Map resultMap = mapper.fromJson(result, Map.class);
|
||||
assertThat((boolean) resultMap.get(AWSLambdaUtils.IS_BASE64_ENCODED)).isTrue();
|
||||
String body = new String(Base64.getDecoder().decode((String) resultMap.get(AWSLambdaUtils.BODY)), StandardCharsets.UTF_8);
|
||||
assertThat(body).isEqualTo("hello");
|
||||
}
|
||||
|
||||
@SuppressWarnings("rawtypes")
|
||||
@Test
|
||||
public void testApiGatewayAsSupplier() throws Exception {
|
||||
@@ -1373,6 +1394,14 @@ public class FunctionInvokerTests {
|
||||
return () -> "boom";
|
||||
}
|
||||
|
||||
@Bean
|
||||
public Function<Message<String>, Message<String>> echoStringMessage() {
|
||||
return m -> {
|
||||
String encodedPayload = Base64.getEncoder().encodeToString(m.getPayload().getBytes(StandardCharsets.UTF_8));
|
||||
return MessageBuilder.withPayload(encodedPayload).setHeader("isBase64Encoded", true).build();
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@Bean
|
||||
public Consumer<String> consume() {
|
||||
|
||||
Reference in New Issue
Block a user