From 205b6174ccf31f5df797ca712ec489eb0dbce009 Mon Sep 17 00:00:00 2001 From: Oleg Zhurakousky Date: Mon, 23 Jan 2023 14:55:46 +0100 Subject: [PATCH] GH-973 Ensure that AWS isBase64Encoded is set dynamically Resolves #973 --- .../function/adapter/aws/AWSLambdaUtils.java | 22 +++++++++----- .../adapter/aws/AWSTypesMessageConverter.java | 4 +++ .../adapter/aws/FunctionInvokerTests.java | 29 +++++++++++++++++++ 3 files changed, 48 insertions(+), 7 deletions(-) diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java index 3f3da31e7..c9b4f5704 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java @@ -46,6 +46,14 @@ public final class AWSLambdaUtils { static final String AWS_EVENT = "aws-event"; + static final String IS_BASE64_ENCODED = "isBase64Encoded"; + + static final String STATUS_CODE = "statusCode"; + + static final String BODY = "body"; + + static final String HEADERS = "headers"; + /** * The name of the headers that stores AWS Context object. */ @@ -130,18 +138,19 @@ public final class AWSLambdaUtils { byte[] responseBytes = responseMessage == null ? "\"OK\"".getBytes() : extractPayload((Message) responseMessage, objectMapper); if (requestMessage.getHeaders().containsKey(AWS_API_GATEWAY) && ((boolean) requestMessage.getHeaders().get(AWS_API_GATEWAY))) { Map response = new HashMap(); - response.put("isBase64Encoded", false); + response.put(IS_BASE64_ENCODED, responseMessage != null && responseMessage.getHeaders().containsKey(IS_BASE64_ENCODED) + ? responseMessage.getHeaders().get(IS_BASE64_ENCODED) : false); AtomicReference headers = new AtomicReference<>(); int statusCode = HttpStatus.OK.value(); if (responseMessage != null) { headers.set(responseMessage.getHeaders()); - statusCode = headers.get().containsKey("statusCode") - ? (int) headers.get().get("statusCode") + statusCode = headers.get().containsKey(STATUS_CODE) + ? (int) headers.get().get(STATUS_CODE) : HttpStatus.OK.value(); } - response.put("statusCode", statusCode); + response.put(STATUS_CODE, statusCode); if (isRequestKinesis(requestMessage)) { HttpStatus httpStatus = HttpStatus.valueOf(statusCode); response.put("statusDescription", httpStatus.toString()); @@ -149,12 +158,11 @@ public final class AWSLambdaUtils { String body = responseMessage == null ? "\"OK\"" : new String(extractPayload((Message) responseMessage, objectMapper), StandardCharsets.UTF_8); - response.put("body", body); - + response.put(BODY, body); if (responseMessage != null) { Map responseHeaders = new HashMap<>(); headers.get().keySet().forEach(key -> responseHeaders.put(key, headers.get().get(key).toString())); - response.put("headers", responseHeaders); + response.put(HEADERS, responseHeaders); } try { diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSTypesMessageConverter.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSTypesMessageConverter.java index 4c73054d2..5d3e88bc6 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSTypesMessageConverter.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSTypesMessageConverter.java @@ -17,6 +17,7 @@ package org.springframework.cloud.function.adapter.aws; import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; import java.util.Map; import com.amazonaws.services.lambda.runtime.serialization.PojoSerializer; @@ -100,6 +101,9 @@ class AWSTypesMessageConverter extends JsonMessageConverter { @Override protected Object convertToInternal(Object payload, @Nullable MessageHeaders headers, @Nullable Object conversionHint) { + if (payload instanceof String && headers.containsKey(AWSLambdaUtils.IS_BASE64_ENCODED) && (boolean) headers.get(AWSLambdaUtils.IS_BASE64_ENCODED)) { + return ((String) payload).getBytes(StandardCharsets.UTF_8); + } return jsonMapper.toJson(payload); } diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java index 3c41663a3..99701a9ca 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.charset.StandardCharsets; +import java.util.Base64; import java.util.Collections; import java.util.Map; import java.util.function.Consumer; @@ -47,6 +48,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.cloud.function.json.JacksonMapper; import org.springframework.cloud.function.json.JsonMapper; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -948,6 +950,25 @@ public class FunctionInvokerTests { assertThat(result.get("body")).isEqualTo("\"Hello from Lambda\""); } + @Test + public void testResponseBase64Encoded() throws Exception { + System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName()); + System.setProperty("spring.cloud.function.definition", "echoStringMessage"); + FunctionInvoker invoker = new FunctionInvoker(); + + InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes()); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + invoker.handleRequest(targetStream, output, null); + + JsonMapper mapper = new JacksonMapper(new ObjectMapper()); + + String result = new String(output.toByteArray(), StandardCharsets.UTF_8); + Map resultMap = mapper.fromJson(result, Map.class); + assertThat((boolean) resultMap.get(AWSLambdaUtils.IS_BASE64_ENCODED)).isTrue(); + String body = new String(Base64.getDecoder().decode((String) resultMap.get(AWSLambdaUtils.BODY)), StandardCharsets.UTF_8); + assertThat(body).isEqualTo("hello"); + } + @SuppressWarnings("rawtypes") @Test public void testApiGatewayAsSupplier() throws Exception { @@ -1373,6 +1394,14 @@ public class FunctionInvokerTests { return () -> "boom"; } + @Bean + public Function, Message> echoStringMessage() { + return m -> { + String encodedPayload = Base64.getEncoder().encodeToString(m.getPayload().getBytes(StandardCharsets.UTF_8)); + return MessageBuilder.withPayload(encodedPayload).setHeader("isBase64Encoded", true).build(); + }; + } + @Bean public Consumer consume() {