From 4ba5ea345281aee3c32fe341348664080b73e691 Mon Sep 17 00:00:00 2001 From: Oleg Zhurakousky Date: Thu, 30 Mar 2023 14:53:29 +0200 Subject: [PATCH] GH-1018 Ensure AWS adapter can pass raw InputStream Resolves #1018 --- .../function/adapter/aws/AWSLambdaUtils.java | 22 +++++++ .../function/adapter/aws/FunctionInvoker.java | 2 +- .../adapter/aws/FunctionInvokerTests.java | 63 +++++++++++++++++++ 3 files changed, 86 insertions(+), 1 deletion(-) diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java index c9b4f5704..a7908b5ee 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java @@ -16,6 +16,8 @@ package org.springframework.cloud.function.adapter.aws; +import java.io.IOException; +import java.io.InputStream; import java.lang.reflect.Type; import java.nio.charset.StandardCharsets; import java.util.HashMap; @@ -31,7 +33,9 @@ import org.springframework.cloud.function.json.JsonMapper; import org.springframework.http.HttpStatus; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.support.GenericMessage; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.util.StreamUtils; /** * @@ -77,6 +81,23 @@ public final class AWSLambdaUtils { || typeName.equals("com.amazonaws.services.lambda.runtime.events.KinesisEvent"); } + @SuppressWarnings("rawtypes") + public static Message generateMessage(InputStream payload, Type inputType, boolean isSupplier, JsonMapper jsonMapper, Context context) throws IOException { + if (inputType != null && FunctionTypeUtils.isMessage(inputType)) { + inputType = FunctionTypeUtils.getImmediateGenericType(inputType, 0); + } + if (inputType != null && InputStream.class.isAssignableFrom(FunctionTypeUtils.getRawType(inputType))) { + MessageBuilder msgBuilder = MessageBuilder.withPayload(payload); + if (context != null) { + msgBuilder.setHeader(AWSLambdaUtils.AWS_CONTEXT, context); + } + return msgBuilder.build(); + } + else { + return generateMessage(StreamUtils.copyToByteArray(payload), inputType, isSupplier, jsonMapper, context); + } + } + public static Message generateMessage(byte[] payload, Type inputType, boolean isSupplier, JsonMapper jsonMapper) { return generateMessage(payload, inputType, isSupplier, jsonMapper, null); } @@ -87,6 +108,7 @@ public final class AWSLambdaUtils { logger.info("Received: " + new String(payload, StandardCharsets.UTF_8)); } + Object structMessage = jsonMapper.fromJson(payload, Object.class); boolean isApiGateway = structMessage instanceof Map && (((Map) structMessage).containsKey("httpMethod") || diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java index e1e706477..54ea419aa 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java @@ -80,7 +80,7 @@ public class FunctionInvoker implements RequestStreamHandler { @Override public void handleRequest(InputStream input, OutputStream output, Context context) throws IOException { Message requestMessage = AWSLambdaUtils - .generateMessage(StreamUtils.copyToByteArray(input), this.function.getInputType(), this.function.isSupplier(), jsonMapper, context); + .generateMessage(input, this.function.getInputType(), this.function.isSupplier(), jsonMapper, context); Object response = this.function.apply(requestMessage); byte[] responseBytes = this.buildResult(requestMessage, response); diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java index 226f2117c..5ff2151f6 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java @@ -57,6 +57,7 @@ import org.springframework.messaging.Message; import org.springframework.messaging.converter.AbstractMessageConverter; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.MimeType; +import org.springframework.util.StreamUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.fail; @@ -971,6 +972,40 @@ public class FunctionInvokerTests { assertThat(result.get("body")).isEqualTo("\"boom\""); } + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testApiGatewayInAndOutInputStream() throws Exception { + System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName()); + System.setProperty("spring.cloud.function.definition", "echoInputStreamToString"); + FunctionInvoker invoker = new FunctionInvoker(); + + InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes()); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + invoker.handleRequest(targetStream, output, null); + + Map result = mapper.readValue(output.toByteArray(), Map.class); + assertThat(result.get("body")).isEqualTo("hello"); + Map headers = (Map) result.get("headers"); + assertThat(headers).isNotEmpty(); + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testApiGatewayInAndOutInputStreamMsg() throws Exception { + System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName()); + System.setProperty("spring.cloud.function.definition", "echoInputStreamMsgToString"); + FunctionInvoker invoker = new FunctionInvoker(); + + InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes()); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + invoker.handleRequest(targetStream, output, null); + + Map result = mapper.readValue(output.toByteArray(), Map.class); + assertThat(result.get("body")).isEqualTo("hello"); + Map headers = (Map) result.get("headers"); + assertThat(headers).isNotEmpty(); + } + @SuppressWarnings("rawtypes") @Test public void testApiGatewayInAndOut() throws Exception { @@ -1400,6 +1435,34 @@ public class FunctionInvokerTests { }; } + @Bean + + public Function echoInputStreamToString() { + return is -> { + try { + String result = StreamUtils.copyToString(is, StandardCharsets.UTF_8); + return result; + } + catch (Exception e) { + throw new RuntimeException(e); + } + }; + } + + @Bean + + public Function, String> echoInputStreamMsgToString() { + return msg -> { + try { + String result = StreamUtils.copyToString(msg.getPayload(), StandardCharsets.UTF_8); + return result; + } + catch (Exception e) { + throw new RuntimeException(e); + } + }; + } + @Bean public Function inputOutputApiEvent() { return v -> {