diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java index 9a4ef4d18..88b59826e 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java @@ -21,7 +21,6 @@ import java.io.InputStream; import java.io.OutputStream; import java.nio.charset.StandardCharsets; import java.util.Calendar; -import java.util.Collections; import java.util.Date; import java.util.HashMap; import java.util.Map; @@ -44,6 +43,7 @@ import org.springframework.cloud.function.context.catalog.FunctionInspector; import org.springframework.cloud.function.utils.FunctionClassUtils; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.core.env.Environment; +import org.springframework.http.HttpStatus; import org.springframework.messaging.Message; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; @@ -78,12 +78,21 @@ public class FunctionInvoker implements RequestStreamHandler { Message responseMessage = this.function.apply(requestMessage); byte[] responseBytes = responseMessage.getPayload(); - if (requestMessage.getHeaders().containsKey("httpMethod")) { + Map requestPayloadMap = this.getRequestPayloadAsMap(requestMessage); + if (requestPayloadMap != null && requestPayloadMap.containsKey("httpMethod")) { Map response = new HashMap(); response.put("isBase64Encoded", false); - response.put("statusCode", 200); + + int statusCode = responseMessage.getHeaders().containsKey("statusCode") + ? (int) responseMessage.getHeaders().get("statusCode") + : 200; + + HttpStatus httpStatus = HttpStatus.valueOf(statusCode); + + response.put("statusCode", statusCode); + response.put("statusDescription", httpStatus.toString()); response.put("body", new String(responseMessage.getPayload(), StandardCharsets.UTF_8)); - response.put("headers", Collections.singletonMap("foo", "bar")); + response.put("headers", responseMessage.getHeaders()); responseBytes = mapper.writeValueAsBytes(response); } @@ -138,4 +147,15 @@ public class FunctionInvoker implements RequestStreamHandler { return message; } + + @SuppressWarnings("unchecked") + private Map getRequestPayloadAsMap(Message message) { + try { + return this.mapper.readValue(message.getPayload(), Map.class); + } + catch (Exception e) { + // ignore + } + return null; + } } diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java index 448c14aec..526b863e5 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java @@ -20,9 +20,11 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.nio.charset.StandardCharsets; +import java.util.Map; import java.util.function.Function; import com.amazonaws.services.lambda.runtime.events.KinesisEvent; +import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; @@ -39,7 +41,31 @@ import static org.assertj.core.api.Assertions.assertThat; */ public class FunctionInvokerTests { - String sampleEvent = "{" + + String sampleLBEvent = "{" + + " \"requestContext\": {" + + " \"elb\": {" + + " \"targetGroupArn\": \"arn:aws:elasticloadbalancing:region:123456789012:targetgroup/my-target-group/6d0ecf831eec9f09\"" + + " }" + + " }," + + " \"httpMethod\": \"GET\"," + + " \"path\": \"/\"," + + " \"headers\": {" + + " \"accept\": \"text/html,application/xhtml+xml\"," + + " \"accept-language\": \"en-US,en;q=0.8\"," + + " \"content-type\": \"text/plain\"," + + " \"cookie\": \"cookies\"," + + " \"host\": \"lambda-846800462-us-east-2.elb.amazonaws.com\"," + + " \"user-agent\": \"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_6)\"," + + " \"x-amzn-trace-id\": \"Root=1-5bdb40ca-556d8b0c50dc66f0511bf520\"," + + " \"x-forwarded-for\": \"72.21.198.66\"," + + " \"x-forwarded-port\": \"443\"," + + " \"x-forwarded-proto\": \"https\"" + + " }," + + " \"isBase64Encoded\": false," + + " \"body\": \"request_body\"" + + "}"; + + String sampleKinesisEvent = "{" + " \"Records\": [" + " {" + " \"kinesis\": {" + @@ -76,33 +102,52 @@ public class FunctionInvokerTests { " ]" + "}"; + @SuppressWarnings("rawtypes") @Test - public void testKinesisStringMessageEvent() throws Exception { - System.setProperty("MAIN_CLASS", KinesisConfiguration.class.getName()); + public void testLBStringMessageEvent() throws Exception { + System.setProperty("MAIN_CLASS", GenericConfiguration.class.getName()); System.setProperty("spring.cloud.function.definition", "echoStringMessage"); FunctionInvoker invoker = new FunctionInvoker(); - InputStream targetStream = new ByteArrayInputStream(this.sampleEvent.getBytes()); + InputStream targetStream = new ByteArrayInputStream(this.sampleLBEvent.getBytes()); ByteArrayOutputStream output = new ByteArrayOutputStream(); invoker.handleRequest(targetStream, output, null); String result = new String(output.toByteArray(), StandardCharsets.UTF_8); - assertThat(result).isEqualTo(this.sampleEvent); + + ObjectMapper mapper = new ObjectMapper(); + Map responseMap = mapper.readValue(result, Map.class); + assertThat(responseMap.get("statusCode")).isEqualTo(200); + assertThat(responseMap.get("statusDescription")).isEqualTo("200 OK"); + } + + @Test + public void testKinesisStringMessageEvent() throws Exception { + System.setProperty("MAIN_CLASS", GenericConfiguration.class.getName()); + System.setProperty("spring.cloud.function.definition", "echoStringMessage"); + FunctionInvoker invoker = new FunctionInvoker(); + + InputStream targetStream = new ByteArrayInputStream(this.sampleKinesisEvent.getBytes()); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + invoker.handleRequest(targetStream, output, null); + + String result = new String(output.toByteArray(), StandardCharsets.UTF_8); + assertThat(result).isEqualTo(this.sampleKinesisEvent); } @Test public void testKinesisStringEvent() throws Exception { - System.setProperty("MAIN_CLASS", KinesisConfiguration.class.getName()); - System.setProperty("spring.cloud.function.definition", "echoStringMessage"); + System.setProperty("MAIN_CLASS", GenericConfiguration.class.getName()); + System.setProperty("spring.cloud.function.definition", "echoString"); FunctionInvoker invoker = new FunctionInvoker(); - InputStream targetStream = new ByteArrayInputStream(this.sampleEvent.getBytes()); + InputStream targetStream = new ByteArrayInputStream(this.sampleKinesisEvent.getBytes()); ByteArrayOutputStream output = new ByteArrayOutputStream(); invoker.handleRequest(targetStream, output, null); String result = new String(output.toByteArray(), StandardCharsets.UTF_8); System.out.println(result); - assertThat(result).isEqualTo(this.sampleEvent); + assertThat(result).isEqualTo(this.sampleKinesisEvent); } @@ -112,7 +157,7 @@ public class FunctionInvokerTests { System.setProperty("spring.cloud.function.definition", "echoKinesisEvent"); FunctionInvoker invoker = new FunctionInvoker(); - InputStream targetStream = new ByteArrayInputStream(this.sampleEvent.getBytes()); + InputStream targetStream = new ByteArrayInputStream(this.sampleKinesisEvent.getBytes()); ByteArrayOutputStream output = new ByteArrayOutputStream(); invoker.handleRequest(targetStream, output, null); @@ -121,11 +166,9 @@ public class FunctionInvokerTests { assertThat(result).contains("\"sequenceNumber\":\"49590338271490256608559692538361571095921575989136588898\""); } - - @EnableAutoConfiguration @Configuration - public static class KinesisConfiguration { + public static class GenericConfiguration { @Bean public Function, Message> echoStringMessage() { @@ -142,6 +185,12 @@ public class FunctionInvokerTests { return v; }; } + } + + + @EnableAutoConfiguration + @Configuration + public static class KinesisConfiguration { @Bean public Function echoKinesisEvent() {