diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java index 35dd5a537..6e6fe4441 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSLambdaUtils.java @@ -16,26 +16,18 @@ package org.springframework.cloud.function.adapter.aws; -import java.io.ByteArrayInputStream; -import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; -import com.amazonaws.services.lambda.runtime.Context; -import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent; -import com.amazonaws.services.lambda.runtime.events.APIGatewayV2HTTPEvent; -import com.amazonaws.services.lambda.runtime.serialization.PojoSerializer; -import com.amazonaws.services.lambda.runtime.serialization.events.LambdaEventSerializers; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.cloud.function.context.catalog.FunctionTypeUtils; import org.springframework.cloud.function.json.JsonMapper; import org.springframework.http.HttpStatus; -import org.springframework.lang.Nullable; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.support.MessageBuilder; @@ -51,18 +43,18 @@ final class AWSLambdaUtils { static final String AWS_API_GATEWAY = "aws-api-gateway"; + static final String AWS_EVENT = "aws-event"; + public static final String AWS_CONTEXT = "aws-context"; private AWSLambdaUtils() { } - public static Message generateMessage(byte[] payload, MessageHeaders headers, - Type inputType, JsonMapper objectMapper) { - return generateMessage(payload, headers, inputType, objectMapper, null); - } - static boolean isSupportedAWSType(Type inputType) { + if (FunctionTypeUtils.isMessage(inputType)) { + inputType = FunctionTypeUtils.getImmediateGenericType(inputType, 0); + } String typeName = inputType.getTypeName(); return typeName.equals("com.amazonaws.services.lambda.runtime.events.APIGatewayV2HTTPEvent") || typeName.equals("com.amazonaws.services.lambda.runtime.events.S3Event") @@ -74,93 +66,30 @@ final class AWSLambdaUtils { } @SuppressWarnings({ "unchecked", "rawtypes" }) - public static Message generateMessage(byte[] payload, MessageHeaders headers, - Type inputType, JsonMapper objectMapper, @Nullable Context awsContext) { - + public static Message generateMessage(byte[] payload, Type inputType, boolean isSupplier, JsonMapper jsonMapper) { if (logger.isInfoEnabled()) { - logger.info("Incoming JSON Event: " + new String(payload)); + logger.info("Received: " + new String(payload, StandardCharsets.UTF_8)); } - if (FunctionTypeUtils.isMessage(inputType)) { - inputType = FunctionTypeUtils.getImmediateGenericType(inputType, 0); - } + Object structMessage = jsonMapper.fromJson(payload, Object.class); + boolean isApiGateway = structMessage instanceof Map + && (((Map) structMessage).containsKey("httpMethod") || + (((Map) structMessage).containsKey("routeKey") && ((Map) structMessage).containsKey("version"))); - MessageBuilder messageBuilder = null; - if (inputType != null && isSupportedAWSType(inputType)) { - PojoSerializer serializer = LambdaEventSerializers.serializerFor(FunctionTypeUtils.getRawType(inputType), Thread.currentThread().getContextClassLoader()); - Object event = serializer.fromJson(new ByteArrayInputStream(payload)); - messageBuilder = MessageBuilder.withPayload(event); - if (event instanceof APIGatewayProxyRequestEvent || event instanceof APIGatewayV2HTTPEvent) { - messageBuilder.setHeader(AWS_API_GATEWAY, true); - logger.info("Incoming request is API Gateway"); - } + Message requestMessage; + MessageBuilder builder = MessageBuilder.withPayload(payload); + if (isApiGateway) { + builder.setHeader(AWSLambdaUtils.AWS_API_GATEWAY, true); } - else { - Object request; - try { - request = objectMapper.fromJson(payload, Object.class); - } - catch (Exception e) { - throw new IllegalStateException(e); - } - - if (request instanceof Map) { - logger.info("Incoming MAP: " + request); - if (((Map) request).containsKey("httpMethod")) { //API Gateway - logger.info("Incoming request is API Gateway"); - boolean mapInputType = (inputType instanceof ParameterizedType && ((Class) ((ParameterizedType) inputType).getRawType()).isAssignableFrom(Map.class)); - if (mapInputType) { - messageBuilder = MessageBuilder.withPayload(request).setHeader("httpMethod", ((Map) request).get("httpMethod")); - messageBuilder.setHeader(AWS_API_GATEWAY, true); - } - else { - messageBuilder = createMessageBuilderForPOJOFunction(objectMapper, (Map) request); - } - } - else if ((((Map) request).containsKey("routeKey") && ((Map) request).containsKey("version"))) { - logger.info("Incoming request is API Gateway v2.0"); - messageBuilder = createMessageBuilderForPOJOFunction(objectMapper, (Map) request); - } - Object providedHeaders = ((Map) request).get("headers"); - if (providedHeaders != null && providedHeaders instanceof Map) { - messageBuilder = MessageBuilder.withPayload(request); - messageBuilder.removeHeader("headers"); - messageBuilder.copyHeaders((Map) providedHeaders); - } - } - else if (request instanceof Iterable) { - messageBuilder = MessageBuilder.withPayload(request); - } + if (!isSupplier && AWSLambdaUtils.isSupportedAWSType(inputType)) { + builder.setHeader(AWSLambdaUtils.AWS_EVENT, true); } - - - if (messageBuilder == null) { - messageBuilder = MessageBuilder.withPayload(payload); + // + if (structMessage instanceof Map && ((Map) structMessage).containsKey("headers")) { + builder.copyHeaders((Map) ((Map) structMessage).get("headers")); } - if (awsContext != null) { - messageBuilder.setHeader(AWS_CONTEXT, awsContext); - } - logger.info("Incoming request headers: " + headers); - - return messageBuilder.copyHeaders(headers).build(); - } - - @SuppressWarnings({ "rawtypes", "unchecked" }) - private static MessageBuilder createMessageBuilderForPOJOFunction(JsonMapper objectMapper, Map request) { - Object body = request.remove("body"); - try { - body = body instanceof String - ? String.valueOf(body).getBytes(StandardCharsets.UTF_8) - : objectMapper.toJson(body); - } - catch (Exception e) { - throw new IllegalStateException(e); - } - logger.info("Body is " + body); - - MessageBuilder messageBuilder = MessageBuilder.withPayload(body).copyHeaders(request); - messageBuilder.setHeader(AWS_API_GATEWAY, true); - return messageBuilder; + requestMessage = builder.build(); + return requestMessage; } private static byte[] extractPayload(Message msg, JsonMapper objectMapper) { diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSTypesMessageConverter.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSTypesMessageConverter.java index 76daaf19a..0fd187145 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSTypesMessageConverter.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/AWSTypesMessageConverter.java @@ -58,6 +58,9 @@ class AWSTypesMessageConverter extends JsonMessageConverter { if (message.getHeaders().containsKey(AWSLambdaUtils.AWS_API_GATEWAY) && ((boolean) message.getHeaders().get(AWSLambdaUtils.AWS_API_GATEWAY))) { return true; } + if (message.getHeaders().containsKey(AWSLambdaUtils.AWS_EVENT) && ((boolean) message.getHeaders().get(AWSLambdaUtils.AWS_EVENT))) { + return true; + } return false; } diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/CustomRuntimeEventLoop.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/CustomRuntimeEventLoop.java index 43cec472d..d67a2edea 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/CustomRuntimeEventLoop.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/CustomRuntimeEventLoop.java @@ -1,5 +1,5 @@ /* - * Copyright 2021-2021 the original author or authors. + * Copyright 2021-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,10 +22,7 @@ import java.net.SocketException; import java.net.URI; import java.nio.charset.StandardCharsets; import java.text.MessageFormat; -import java.util.Arrays; -import java.util.Collection; import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -45,7 +42,6 @@ import org.springframework.http.MediaType; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; import org.springframework.messaging.Message; -import org.springframework.messaging.MessageHeaders; import org.springframework.util.Assert; import org.springframework.web.client.RestTemplate; @@ -137,9 +133,11 @@ public final class CustomRuntimeEventLoop implements SmartLifecycle { } if (response != null) { - FunctionInvocationWrapper function = locateFunction(environment, functionCatalog, response.getHeaders().getContentType()); - Message eventMessage = AWSLambdaUtils.generateMessage(response.getBody().getBytes(StandardCharsets.UTF_8), - fromHttp(response.getHeaders()), function.getInputType(), mapper); + FunctionInvocationWrapper function = locateFunction(environment, functionCatalog, response.getHeaders()); + + Message eventMessage = AWSLambdaUtils + .generateMessage(response.getBody().getBytes(StandardCharsets.UTF_8), function.getInputType(), function.isSupplier(), mapper); + if (logger.isDebugEnabled()) { logger.debug("Event message: " + eventMessage); } @@ -206,7 +204,8 @@ public final class CustomRuntimeEventLoop implements SmartLifecycle { return null; } - private FunctionInvocationWrapper locateFunction(Environment environment, FunctionCatalog functionCatalog, MediaType contentType) { + private FunctionInvocationWrapper locateFunction(Environment environment, FunctionCatalog functionCatalog, HttpHeaders httpHeaders) { + MediaType contentType = httpHeaders.getContentType(); String handlerName = environment.getProperty("DEFAULT_HANDLER"); if (logger.isDebugEnabled()) { logger.debug("Value of DEFAULT_HANDLER env: " + handlerName); @@ -235,6 +234,15 @@ public final class CustomRuntimeEventLoop implements SmartLifecycle { function = functionCatalog.lookup(handlerName, contentType.toString()); } + if (function == null) { + logger.info("Could not determine DEFAULT_HANDLER, _HANDLER or 'spring.cloud.function.definition'"); + handlerName = httpHeaders.getFirst("spring.cloud.function.definition"); + if (logger.isDebugEnabled()) { + logger.debug("Value of 'spring.cloud.function.definition' header: " + handlerName); + } + function = functionCatalog.lookup(handlerName, contentType.toString()); + } + Assert.notNull(function, "Failed to locate function. Tried locating default function, " + "function by 'DEFAULT_HANDLER', '_HANDLER' env variable as well as'spring.cloud.function.definition'. " + "Functions available in catalog are: " + functionCatalog.getNames(null)); @@ -244,25 +252,6 @@ public final class CustomRuntimeEventLoop implements SmartLifecycle { return function; } - private MessageHeaders fromHttp(HttpHeaders headers) { - Map map = new LinkedHashMap<>(); - for (String name : headers.keySet()) { - Collection values = multi(headers.get(name)); - name = name.toLowerCase(); - Object value = values == null ? null - : (values.size() == 1 ? values.iterator().next() : values); - if (name.toLowerCase().equals(HttpHeaders.CONTENT_TYPE.toLowerCase())) { - name = MessageHeaders.CONTENT_TYPE; - } - map.put(name, value); - } - return new MessageHeaders(map); - } - - private Collection multi(Object value) { - return value instanceof Collection ? (Collection) value : Arrays.asList(value); - } - private static String extractVersion() { String path = CustomRuntimeEventLoop.class.getProtectionDomain().getCodeSource().getLocation().toString(); int endIndex = path.lastIndexOf('.'); diff --git a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java index 071a2c6a4..3d51ffe81 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java +++ b/spring-cloud-function-adapters/spring-cloud-function-adapter-aws/src/main/java/org/springframework/cloud/function/adapter/aws/FunctionInvoker.java @@ -1,5 +1,5 @@ /* - * Copyright 2019-2021 the original author or authors. + * Copyright 2019-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,13 +19,10 @@ package org.springframework.cloud.function.adapter.aws; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Calendar; -import java.util.Collections; import java.util.Date; import java.util.List; -import java.util.Map; import java.util.Set; import com.amazonaws.services.lambda.runtime.Context; @@ -53,7 +50,6 @@ import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.core.env.Environment; import org.springframework.messaging.Message; -import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -80,35 +76,11 @@ public class FunctionInvoker implements RequestStreamHandler { this.start(); } - @SuppressWarnings({ "rawtypes", "unchecked" }) + @SuppressWarnings({ "rawtypes" }) @Override public void handleRequest(InputStream input, OutputStream output, Context context) throws IOException { - final byte[] payload = StreamUtils.copyToByteArray(input); - - if (logger.isInfoEnabled()) { - logger.info("Received: " + new String(payload, StandardCharsets.UTF_8)); - } - - Object structMessage = this.jsonMapper.fromJson(payload, Object.class); - - boolean isApiGateway = structMessage instanceof Map - && (((Map) structMessage).containsKey("httpMethod") || - (((Map) structMessage).containsKey("routeKey") && ((Map) structMessage).containsKey("version"))); - - - // TODO we should eventually completely delegate to message converter - Message requestMessage; - if (isApiGateway) { - MessageBuilder builder = MessageBuilder.withPayload(payload).setHeader(AWSLambdaUtils.AWS_API_GATEWAY, true); - if (structMessage instanceof Map && ((Map) structMessage).containsKey("headers")) { - builder.copyHeaders((Map) ((Map) structMessage).get("headers")); - } - requestMessage = builder.build(); - } - else { - requestMessage = AWSLambdaUtils - .generateMessage(payload, new MessageHeaders(Collections.emptyMap()), function.getInputType(), this.jsonMapper, context); - } + Message requestMessage = AWSLambdaUtils + .generateMessage(StreamUtils.copyToByteArray(input), this.function.getInputType(), this.function.isSupplier(), jsonMapper); Object response = this.function.apply(requestMessage); byte[] responseBytes = this.buildResult(requestMessage, response);