GH-1018 Ensure AWS adapter can pass raw InputStream

Resolves #1018
This commit is contained in:
Oleg Zhurakousky
2023-03-30 14:53:29 +02:00
parent 7365debf44
commit 4ba5ea3452
3 changed files with 86 additions and 1 deletions

View File

@@ -16,6 +16,8 @@
package org.springframework.cloud.function.adapter.aws;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
@@ -31,7 +33,9 @@ import org.springframework.cloud.function.json.JsonMapper;
import org.springframework.http.HttpStatus;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.support.GenericMessage;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.StreamUtils;
/**
*
@@ -77,6 +81,23 @@ public final class AWSLambdaUtils {
|| typeName.equals("com.amazonaws.services.lambda.runtime.events.KinesisEvent");
}
@SuppressWarnings("rawtypes")
public static Message generateMessage(InputStream payload, Type inputType, boolean isSupplier, JsonMapper jsonMapper, Context context) throws IOException {
if (inputType != null && FunctionTypeUtils.isMessage(inputType)) {
inputType = FunctionTypeUtils.getImmediateGenericType(inputType, 0);
}
if (inputType != null && InputStream.class.isAssignableFrom(FunctionTypeUtils.getRawType(inputType))) {
MessageBuilder msgBuilder = MessageBuilder.withPayload(payload);
if (context != null) {
msgBuilder.setHeader(AWSLambdaUtils.AWS_CONTEXT, context);
}
return msgBuilder.build();
}
else {
return generateMessage(StreamUtils.copyToByteArray(payload), inputType, isSupplier, jsonMapper, context);
}
}
public static Message<byte[]> generateMessage(byte[] payload, Type inputType, boolean isSupplier, JsonMapper jsonMapper) {
return generateMessage(payload, inputType, isSupplier, jsonMapper, null);
}
@@ -87,6 +108,7 @@ public final class AWSLambdaUtils {
logger.info("Received: " + new String(payload, StandardCharsets.UTF_8));
}
Object structMessage = jsonMapper.fromJson(payload, Object.class);
boolean isApiGateway = structMessage instanceof Map
&& (((Map<String, Object>) structMessage).containsKey("httpMethod") ||

View File

@@ -80,7 +80,7 @@ public class FunctionInvoker implements RequestStreamHandler {
@Override
public void handleRequest(InputStream input, OutputStream output, Context context) throws IOException {
Message requestMessage = AWSLambdaUtils
.generateMessage(StreamUtils.copyToByteArray(input), this.function.getInputType(), this.function.isSupplier(), jsonMapper, context);
.generateMessage(input, this.function.getInputType(), this.function.isSupplier(), jsonMapper, context);
Object response = this.function.apply(requestMessage);
byte[] responseBytes = this.buildResult(requestMessage, response);

View File

@@ -57,6 +57,7 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.converter.AbstractMessageConverter;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.MimeType;
import org.springframework.util.StreamUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.fail;
@@ -971,6 +972,40 @@ public class FunctionInvokerTests {
assertThat(result.get("body")).isEqualTo("\"boom\"");
}
@SuppressWarnings({ "rawtypes", "unchecked" })
@Test
public void testApiGatewayInAndOutInputStream() throws Exception {
System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName());
System.setProperty("spring.cloud.function.definition", "echoInputStreamToString");
FunctionInvoker invoker = new FunctionInvoker();
InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
ByteArrayOutputStream output = new ByteArrayOutputStream();
invoker.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("hello");
Map headers = (Map) result.get("headers");
assertThat(headers).isNotEmpty();
}
@SuppressWarnings({ "rawtypes", "unchecked" })
@Test
public void testApiGatewayInAndOutInputStreamMsg() throws Exception {
System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName());
System.setProperty("spring.cloud.function.definition", "echoInputStreamMsgToString");
FunctionInvoker invoker = new FunctionInvoker();
InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
ByteArrayOutputStream output = new ByteArrayOutputStream();
invoker.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("hello");
Map headers = (Map) result.get("headers");
assertThat(headers).isNotEmpty();
}
@SuppressWarnings("rawtypes")
@Test
public void testApiGatewayInAndOut() throws Exception {
@@ -1400,6 +1435,34 @@ public class FunctionInvokerTests {
};
}
@Bean
public Function<InputStream, String> echoInputStreamToString() {
return is -> {
try {
String result = StreamUtils.copyToString(is, StandardCharsets.UTF_8);
return result;
}
catch (Exception e) {
throw new RuntimeException(e);
}
};
}
@Bean
public Function<Message<InputStream>, String> echoInputStreamMsgToString() {
return msg -> {
try {
String result = StreamUtils.copyToString(msg.getPayload(), StandardCharsets.UTF_8);
return result;
}
catch (Exception e) {
throw new RuntimeException(e);
}
};
}
@Bean
public Function<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> inputOutputApiEvent() {
return v -> {