diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java index f98c13b11..e74034776 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java @@ -16,20 +16,30 @@ package org.springframework.cloud.function.adapter.aws; +import java.io.IOException; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.nio.charset.StandardCharsets; +import java.util.Calendar; +import java.util.Date; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; import com.amazonaws.services.lambda.runtime.Context; import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.MapperFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.module.SimpleModule; +import com.fasterxml.jackson.datatype.joda.JodaModule; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.cloud.function.context.catalog.FunctionTypeUtils; -import org.springframework.cloud.function.json.JsonMapper; import org.springframework.http.HttpStatus; import org.springframework.lang.Nullable; import org.springframework.messaging.Message; @@ -52,20 +62,30 @@ final class AWSLambdaUtils { } public static Message generateMessage(byte[] payload, MessageHeaders headers, - Type inputType, JsonMapper mapper) { - return generateMessage(payload, headers, inputType, mapper, null); + Type inputType, ObjectMapper objectMapper) { + return generateMessage(payload, headers, inputType, objectMapper, null); } @SuppressWarnings({ "unchecked", "rawtypes" }) public static Message generateMessage(byte[] payload, MessageHeaders headers, - Type inputType, JsonMapper mapper, @Nullable Context awsContext) { + Type inputType, ObjectMapper objectMapper, @Nullable Context awsContext) { + + if (!objectMapper.isEnabled(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES)) { + configureObjectMapper(objectMapper); + } if (logger.isInfoEnabled()) { logger.info("Incoming JSON Event: " + new String(payload)); } MessageBuilder messageBuilder = null; - Object request = mapper.fromJson(payload, Object.class); + Object request; + try { + request = objectMapper.readValue(payload, Object.class); + } + catch (Exception e) { + throw new IllegalStateException(e); + } if (FunctionTypeUtils.isMessage(inputType)) { inputType = FunctionTypeUtils.getImmediateGenericType(inputType, 0); } @@ -81,7 +101,7 @@ final class AWSLambdaUtils { else if (requestMap.containsKey("httpMethod")) { // API Gateway logger.info("Incoming request is API Gateway"); if (isTypeAnApiGatewayRequest(inputType)) { - APIGatewayProxyRequestEvent gatewayEvent = mapper.fromJson(requestMap, APIGatewayProxyRequestEvent.class); + APIGatewayProxyRequestEvent gatewayEvent = objectMapper.convertValue(requestMap, APIGatewayProxyRequestEvent.class); messageBuilder = MessageBuilder.withPayload(gatewayEvent); } else if (mapInputType) { @@ -89,7 +109,15 @@ final class AWSLambdaUtils { } else { Object body = requestMap.remove("body"); - body = body instanceof String ? String.valueOf(body).getBytes(StandardCharsets.UTF_8) : mapper.toJson(body); + try { + body = body instanceof String + ? String.valueOf(body).getBytes(StandardCharsets.UTF_8) + : objectMapper.writeValueAsBytes(body); + } + catch (Exception e) { + throw new IllegalStateException(e); + } + messageBuilder = MessageBuilder.withPayload(body).copyHeaders(requestMap); } } @@ -105,17 +133,24 @@ final class AWSLambdaUtils { @SuppressWarnings({ "rawtypes", "unchecked" }) public static byte[] generateOutput(Message requestMessage, Message responseMessage, - JsonMapper mapper) { - byte[] responseBytes = responseMessage.getPayload(); + ObjectMapper objectMapper) { + if (!objectMapper.isEnabled(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES)) { + configureObjectMapper(objectMapper); + } + byte[] responseBytes = responseMessage == null ? "\"OK\"".getBytes() : responseMessage.getPayload(); if (requestMessage.getHeaders().containsKey("httpMethod") - || isPayloadAnApiGatewayRequest(responseMessage.getPayload())) { // API Gateway + || isPayloadAnApiGatewayRequest(requestMessage.getPayload())) { // API Gateway Map response = new HashMap(); response.put("isBase64Encoded", false); - MessageHeaders headers = responseMessage.getHeaders(); - int statusCode = headers.containsKey("statusCode") - ? (int) headers.get("statusCode") - : 200; + AtomicReference headers = new AtomicReference<>(); + int statusCode = HttpStatus.OK.value(); + if (responseMessage != null) { + headers.set(responseMessage.getHeaders()); + statusCode = headers.get().containsKey("statusCode") + ? (int) headers.get().get("statusCode") + : HttpStatus.OK.value(); + } response.put("statusCode", statusCode); if (isRequestKinesis(requestMessage)) { @@ -123,19 +158,43 @@ final class AWSLambdaUtils { response.put("statusDescription", httpStatus.toString()); } - String body = new String(responseMessage.getPayload(), StandardCharsets.UTF_8).replaceAll("\\\"", "\""); + String body = responseMessage == null + ? "\"OK\"" : new String(responseMessage.getPayload(), StandardCharsets.UTF_8).replaceAll("\\\"", "\""); response.put("body", body); - Map responseHeaders = new HashMap<>(); - headers.keySet().forEach(key -> responseHeaders.put(key, headers.get(key).toString())); + if (responseMessage != null) { + Map responseHeaders = new HashMap<>(); + headers.get().keySet().forEach(key -> responseHeaders.put(key, headers.get().get(key).toString())); + response.put("headers", responseHeaders); + } - response.put("headers", responseHeaders); - responseBytes = mapper.toJson(response); + try { + responseBytes = objectMapper.writeValueAsBytes(response); + } + catch (Exception e) { + throw new IllegalStateException("Failed to serialize AWS Lambda output", e); + } } return responseBytes; } + private static void configureObjectMapper(ObjectMapper objectMapper) { + SimpleModule module = new SimpleModule(); + module.addDeserializer(Date.class, new JsonDeserializer() { + @Override + public Date deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) + throws IOException { + Calendar calendar = Calendar.getInstance(); + calendar.setTimeInMillis(jsonParser.getValueAsLong()); + return calendar.getTime(); + } + }); + objectMapper.registerModule(module); + objectMapper.registerModule(new JodaModule()); + objectMapper.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true); + } + private static boolean isPayloadAnApiGatewayRequest(Object payload) { return isAPIGatewayProxyRequestEventPresent() ? payload instanceof APIGatewayProxyRequestEvent diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/CustomRuntimeEventLoop.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/CustomRuntimeEventLoop.java index 29847c775..2cab7436b 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/CustomRuntimeEventLoop.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/CustomRuntimeEventLoop.java @@ -24,6 +24,7 @@ import java.util.Collection; import java.util.LinkedHashMap; import java.util.Map; +import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -31,7 +32,6 @@ import org.springframework.boot.CommandLineRunner; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.cloud.function.context.FunctionCatalog; import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper; -import org.springframework.cloud.function.json.JsonMapper; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -45,6 +45,8 @@ import org.springframework.util.Assert; import org.springframework.web.client.RestTemplate; /** + * Event loop and necessary configurations to support AWS Lambda + * Custom Runtime - https://docs.aws.amazon.com/lambda/latest/dg/runtimes-custom.html. * * @author Oleg Zhurakousky * @since 3.1.1 @@ -82,7 +84,7 @@ public class CustomRuntimeEventLoop { RequestEntity requestEntity = RequestEntity.get(URI.create(eventUri)).build(); FunctionCatalog functionCatalog = context.getBean(FunctionCatalog.class); RestTemplate rest = new RestTemplate(); - JsonMapper mapper = context.getBean(JsonMapper.class); + ObjectMapper mapper = context.getBean(ObjectMapper.class); logger.info("Entering event loop"); while (true) { @@ -93,7 +95,6 @@ public class CustomRuntimeEventLoop { } FunctionInvocationWrapper function = locateFunction(functionCatalog, response.getHeaders().getContentType()); - Message eventMessage = AWSLambdaUtils.generateMessage(response.getBody().getBytes(StandardCharsets.UTF_8), fromHttp(response.getHeaders()), function.getInputType(), mapper); if (logger.isDebugEnabled()) { @@ -106,9 +107,8 @@ public class CustomRuntimeEventLoop { Message responseMessage = (Message) function.apply(eventMessage); - String reply = new String(responseMessage.getPayload(), StandardCharsets.UTF_8); - if (logger.isDebugEnabled()) { - logger.debug("Reply from function: " + reply); + if (responseMessage != null && logger.isDebugEnabled()) { + logger.debug("Reply from function: " + new String(responseMessage.getPayload(), StandardCharsets.UTF_8)); } byte[] outputBody = AWSLambdaUtils.generateOutput(eventMessage, responseMessage, mapper); diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java index f4bc69333..833c4b75d 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java @@ -1,5 +1,5 @@ /* - * Copyright 2019-2020 the original author or authors. + * Copyright 2019-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,18 +19,12 @@ package org.springframework.cloud.function.adapter.aws; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.lang.reflect.ParameterizedType; -import java.lang.reflect.Type; -import java.nio.charset.StandardCharsets; import java.util.Calendar; +import java.util.Collections; import java.util.Date; -import java.util.HashMap; -import java.util.List; -import java.util.Map; import com.amazonaws.services.lambda.runtime.Context; import com.amazonaws.services.lambda.runtime.RequestStreamHandler; -import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonDeserializer; @@ -43,15 +37,12 @@ import org.apache.commons.logging.LogFactory; import org.springframework.boot.SpringApplication; import org.springframework.cloud.function.context.FunctionCatalog; -import org.springframework.cloud.function.context.catalog.FunctionTypeUtils; import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper; import org.springframework.cloud.function.utils.FunctionClassUtils; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.core.env.Environment; -import org.springframework.http.HttpStatus; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; -import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.util.StreamUtils; import org.springframework.util.StringUtils; @@ -68,7 +59,7 @@ public class FunctionInvoker implements RequestStreamHandler { private static Log logger = LogFactory.getLog(FunctionInvoker.class); - private ObjectMapper mapper; + private ObjectMapper objectMapper; private FunctionInvocationWrapper function; @@ -79,50 +70,24 @@ public class FunctionInvoker implements RequestStreamHandler { @SuppressWarnings({ "unchecked", "rawtypes" }) @Override public void handleRequest(InputStream input, OutputStream output, Context context) throws IOException { - Message requestMessage = this.generateMessage(input, context); + final byte[] payload = StreamUtils.copyToByteArray(input); + Message requestMessage = AWSLambdaUtils + .generateMessage(payload, new MessageHeaders(Collections.emptyMap()), function.getInputType(), this.objectMapper, context); Message responseMessage = (Message) this.function.apply(requestMessage); - byte[] responseBytes = responseMessage.getPayload(); - if (requestMessage.getHeaders().containsKey("httpMethod") || requestMessage.getPayload() instanceof APIGatewayProxyRequestEvent) { // API Gateway - Map response = new HashMap(); - response.put("isBase64Encoded", false); - - MessageHeaders headers = responseMessage.getHeaders(); - int statusCode = headers.containsKey("statusCode") - ? (int) headers.get("statusCode") - : 200; - - response.put("statusCode", statusCode); - if (isKinesis(requestMessage)) { - HttpStatus httpStatus = HttpStatus.valueOf(statusCode); - response.put("statusDescription", httpStatus.toString()); - } - - String body = new String(responseMessage.getPayload(), StandardCharsets.UTF_8).replaceAll("\\\"", "\""); - response.put("body", body); - - Map responseHeaders = new HashMap<>(); - headers.keySet().forEach(key -> responseHeaders.put(key, headers.get(key).toString())); - - response.put("headers", responseHeaders); - responseBytes = mapper.writeValueAsBytes(response); - } + byte[] responseBytes = AWSLambdaUtils.generateOutput(requestMessage, responseMessage, this.objectMapper); StreamUtils.copy(responseBytes, output); } - private boolean isKinesis(Message requestMessage) { - return requestMessage.getHeaders().containsKey("Records"); - } - private void start() { ConfigurableApplicationContext context = SpringApplication.run(FunctionClassUtils.getStartClass()); Environment environment = context.getEnvironment(); String functionName = environment.getProperty("spring.cloud.function.definition"); FunctionCatalog functionCatalog = context.getBean(FunctionCatalog.class); - this.mapper = context.getBean(ObjectMapper.class); - this.configureObjectMapper(); + this.objectMapper = context.getBean(ObjectMapper.class); + //this.configureObjectMapper(); if (logger.isInfoEnabled()) { logger.info("Locating function: '" + functionName + "'"); @@ -138,8 +103,6 @@ public class FunctionInvoker implements RequestStreamHandler { if (logger.isInfoEnabled()) { logger.info("Located function: '" + functionName + "'"); } - - mapper.registerModule(new JodaModule()); } private void configureObjectMapper() { @@ -153,79 +116,8 @@ public class FunctionInvoker implements RequestStreamHandler { return calendar.getTime(); } }); - mapper.registerModule(module); - mapper.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true); - } - - @SuppressWarnings({ "unchecked", "rawtypes" }) - private Message generateMessage(InputStream input, Context context) throws IOException { - final byte[] payload = StreamUtils.copyToByteArray(input); - - if (logger.isInfoEnabled()) { - logger.info("Incoming JSON for ApiGateway Event: " + new String(payload)); - } - - MessageBuilder messageBuilder = null; - Object request = this.mapper.readValue(payload, Object.class); - Type inputType = function.getInputType(); - if (FunctionTypeUtils.isMessage(inputType)) { - inputType = FunctionTypeUtils.getImmediateGenericType(inputType, 0); - } - boolean mapInputType = (inputType instanceof ParameterizedType && ((Class) ((ParameterizedType) inputType).getRawType()).isAssignableFrom(Map.class)); - if (request instanceof Map) { - Map requestMap = (Map) request; - if (requestMap.containsKey("Records")) { - List> records = (List>) requestMap.get("Records"); - Assert.notEmpty(records, "Incoming event has no records: " + requestMap); - this.logEvent(records); - messageBuilder = MessageBuilder.withPayload(payload); - } - else if (requestMap.containsKey("httpMethod")) { // API Gateway - logger.info("Incoming request is API Gateway"); - if (inputType.getTypeName().endsWith(APIGatewayProxyRequestEvent.class.getSimpleName())) { - APIGatewayProxyRequestEvent gatewayEvent = this.mapper.convertValue(requestMap, APIGatewayProxyRequestEvent.class); - messageBuilder = MessageBuilder.withPayload(gatewayEvent); - } - else if (mapInputType) { - messageBuilder = MessageBuilder.withPayload(requestMap).setHeader("httpMethod", requestMap.get("httpMethod")); - } - else { - Object body = requestMap.remove("body"); - body = body instanceof String ? String.valueOf(body).getBytes(StandardCharsets.UTF_8) : mapper.writeValueAsBytes(body); - messageBuilder = MessageBuilder.withPayload(body).copyHeaders(requestMap); - } - } - } - if (messageBuilder == null) { - messageBuilder = MessageBuilder.withPayload(payload); - } - return messageBuilder.setHeader("aws-context", context).build(); - } - - private void logEvent(List> records) { - if (this.isKinesisEvent(records.get(0))) { - logger.info("Incoming request is Kinesis Event"); - } - else if (this.isS3Event(records.get(0))) { - logger.info("Incoming request is S3 Event"); - } - else if (this.isSNSEvent(records.get(0))) { - logger.info("Incoming request is SNS Event"); - } - else { - logger.info("Incoming request is SQS Event"); - } - } - - private boolean isSNSEvent(Map record) { - return record.containsKey("Sns"); - } - - private boolean isS3Event(Map record) { - return record.containsKey("s3"); - } - - private boolean isKinesisEvent(Map record) { - return record.containsKey("kinesis"); + this.objectMapper.registerModule(module); + this.objectMapper.registerModule(new JodaModule()); + this.objectMapper.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true); } } diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java index 6a3a85284..f60dbdc8d 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/test/java/org/springframework/cloud/function/adapter/aws/FunctionInvokerTests.java @@ -21,6 +21,7 @@ import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.Map; +import java.util.function.Consumer; import java.util.function.Function; import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent; @@ -656,6 +657,21 @@ public class FunctionInvokerTests { assertThat(result.get("body")).isEqualTo("\"hello\""); } + @SuppressWarnings("rawtypes") + @Test + public void testApiGatewayEventConsumer() throws Exception { + System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName()); + System.setProperty("spring.cloud.function.definition", "consume"); + FunctionInvoker invoker = new FunctionInvoker(); + + InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes()); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + invoker.handleRequest(targetStream, output, null); + + Map result = mapper.readValue(output.toByteArray(), Map.class); + assertThat(result.get("body")).isEqualTo("\"OK\""); + } + @EnableAutoConfiguration @Configuration public static class KinesisConfiguration { @@ -823,6 +839,12 @@ public class FunctionInvokerTests { @EnableAutoConfiguration @Configuration public static class ApiGatewayConfiguration { + + @Bean + public Consumer consume() { + return v -> System.out.println(v); + } + @Bean public Function uppercase() { return v -> v.toUpperCase(); diff --git a/spring-cloud-function-samples/function-sample-aws-custom-bean/src/main/java/com/example/LambdaApplication.java b/spring-cloud-function-samples/function-sample-aws-custom-bean/src/main/java/com/example/LambdaApplication.java index 8a6f62d61..7fb2ecec8 100644 --- a/spring-cloud-function-samples/function-sample-aws-custom-bean/src/main/java/com/example/LambdaApplication.java +++ b/spring-cloud-function-samples/function-sample-aws-custom-bean/src/main/java/com/example/LambdaApplication.java @@ -1,6 +1,7 @@ package com.example; import java.util.Arrays; +import java.util.function.Consumer; import java.util.function.Function; import org.apache.commons.logging.Log; @@ -18,6 +19,13 @@ public class LambdaApplication { private static Log logger = LogFactory.getLog(LambdaApplication.class); + @Bean + public Consumer consume() { + return value -> { + logger.info("Consuming: " + value); + }; + } + @Bean public Function uppercase() { return value -> { @@ -53,8 +61,6 @@ public class LambdaApplication { public static void main(String[] args) { - System.out.println("=====> ENVIRONMENT: " + System.getenv("AWS_LAMBDA_RUNTIME_API")); - //FunctionalSpringApplication.run(LambdaApplication.class, args); logger.info("==> Starting: LambdaApplication"); if (!ObjectUtils.isEmpty(args)) { logger.info("==> args: " + Arrays.asList(args)); @@ -62,10 +68,4 @@ public class LambdaApplication { SpringApplication.run(LambdaApplication.class, args); } -// @Override -// public void initialize(GenericApplicationContext context) { -// context.registerBean("uppercase", FunctionRegistration.class, -// () -> new FunctionRegistration<>(uppercase()).type( -// FunctionType.from(String.class).to(String.class))); -// } }