Reworked AWSUtils and related classes to delegate type conversion to AWSTypesMessageConverter
Resolves #889 polish
This commit is contained in:
@@ -16,26 +16,18 @@
|
||||
|
||||
package org.springframework.cloud.function.adapter.aws;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.lang.reflect.ParameterizedType;
|
||||
import java.lang.reflect.Type;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import com.amazonaws.services.lambda.runtime.Context;
|
||||
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
|
||||
import com.amazonaws.services.lambda.runtime.events.APIGatewayV2HTTPEvent;
|
||||
import com.amazonaws.services.lambda.runtime.serialization.PojoSerializer;
|
||||
import com.amazonaws.services.lambda.runtime.serialization.events.LambdaEventSerializers;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
|
||||
import org.springframework.cloud.function.json.JsonMapper;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.MessageHeaders;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
@@ -51,18 +43,18 @@ final class AWSLambdaUtils {
|
||||
|
||||
static final String AWS_API_GATEWAY = "aws-api-gateway";
|
||||
|
||||
static final String AWS_EVENT = "aws-event";
|
||||
|
||||
public static final String AWS_CONTEXT = "aws-context";
|
||||
|
||||
private AWSLambdaUtils() {
|
||||
|
||||
}
|
||||
|
||||
public static Message<byte[]> generateMessage(byte[] payload, MessageHeaders headers,
|
||||
Type inputType, JsonMapper objectMapper) {
|
||||
return generateMessage(payload, headers, inputType, objectMapper, null);
|
||||
}
|
||||
|
||||
static boolean isSupportedAWSType(Type inputType) {
|
||||
if (FunctionTypeUtils.isMessage(inputType)) {
|
||||
inputType = FunctionTypeUtils.getImmediateGenericType(inputType, 0);
|
||||
}
|
||||
String typeName = inputType.getTypeName();
|
||||
return typeName.equals("com.amazonaws.services.lambda.runtime.events.APIGatewayV2HTTPEvent")
|
||||
|| typeName.equals("com.amazonaws.services.lambda.runtime.events.S3Event")
|
||||
@@ -74,93 +66,30 @@ final class AWSLambdaUtils {
|
||||
}
|
||||
|
||||
@SuppressWarnings({ "unchecked", "rawtypes" })
|
||||
public static Message<byte[]> generateMessage(byte[] payload, MessageHeaders headers,
|
||||
Type inputType, JsonMapper objectMapper, @Nullable Context awsContext) {
|
||||
|
||||
public static Message<byte[]> generateMessage(byte[] payload, Type inputType, boolean isSupplier, JsonMapper jsonMapper) {
|
||||
if (logger.isInfoEnabled()) {
|
||||
logger.info("Incoming JSON Event: " + new String(payload));
|
||||
logger.info("Received: " + new String(payload, StandardCharsets.UTF_8));
|
||||
}
|
||||
|
||||
if (FunctionTypeUtils.isMessage(inputType)) {
|
||||
inputType = FunctionTypeUtils.getImmediateGenericType(inputType, 0);
|
||||
}
|
||||
Object structMessage = jsonMapper.fromJson(payload, Object.class);
|
||||
boolean isApiGateway = structMessage instanceof Map
|
||||
&& (((Map<String, Object>) structMessage).containsKey("httpMethod") ||
|
||||
(((Map<String, Object>) structMessage).containsKey("routeKey") && ((Map) structMessage).containsKey("version")));
|
||||
|
||||
MessageBuilder messageBuilder = null;
|
||||
if (inputType != null && isSupportedAWSType(inputType)) {
|
||||
PojoSerializer<?> serializer = LambdaEventSerializers.serializerFor(FunctionTypeUtils.getRawType(inputType), Thread.currentThread().getContextClassLoader());
|
||||
Object event = serializer.fromJson(new ByteArrayInputStream(payload));
|
||||
messageBuilder = MessageBuilder.withPayload(event);
|
||||
if (event instanceof APIGatewayProxyRequestEvent || event instanceof APIGatewayV2HTTPEvent) {
|
||||
messageBuilder.setHeader(AWS_API_GATEWAY, true);
|
||||
logger.info("Incoming request is API Gateway");
|
||||
}
|
||||
Message<byte[]> requestMessage;
|
||||
MessageBuilder<byte[]> builder = MessageBuilder.withPayload(payload);
|
||||
if (isApiGateway) {
|
||||
builder.setHeader(AWSLambdaUtils.AWS_API_GATEWAY, true);
|
||||
}
|
||||
else {
|
||||
Object request;
|
||||
try {
|
||||
request = objectMapper.fromJson(payload, Object.class);
|
||||
}
|
||||
catch (Exception e) {
|
||||
throw new IllegalStateException(e);
|
||||
}
|
||||
|
||||
if (request instanceof Map) {
|
||||
logger.info("Incoming MAP: " + request);
|
||||
if (((Map) request).containsKey("httpMethod")) { //API Gateway
|
||||
logger.info("Incoming request is API Gateway");
|
||||
boolean mapInputType = (inputType instanceof ParameterizedType && ((Class<?>) ((ParameterizedType) inputType).getRawType()).isAssignableFrom(Map.class));
|
||||
if (mapInputType) {
|
||||
messageBuilder = MessageBuilder.withPayload(request).setHeader("httpMethod", ((Map) request).get("httpMethod"));
|
||||
messageBuilder.setHeader(AWS_API_GATEWAY, true);
|
||||
}
|
||||
else {
|
||||
messageBuilder = createMessageBuilderForPOJOFunction(objectMapper, (Map) request);
|
||||
}
|
||||
}
|
||||
else if ((((Map) request).containsKey("routeKey") && ((Map) request).containsKey("version"))) {
|
||||
logger.info("Incoming request is API Gateway v2.0");
|
||||
messageBuilder = createMessageBuilderForPOJOFunction(objectMapper, (Map) request);
|
||||
}
|
||||
Object providedHeaders = ((Map) request).get("headers");
|
||||
if (providedHeaders != null && providedHeaders instanceof Map) {
|
||||
messageBuilder = MessageBuilder.withPayload(request);
|
||||
messageBuilder.removeHeader("headers");
|
||||
messageBuilder.copyHeaders((Map<String, Object>) providedHeaders);
|
||||
}
|
||||
}
|
||||
else if (request instanceof Iterable) {
|
||||
messageBuilder = MessageBuilder.withPayload(request);
|
||||
}
|
||||
if (!isSupplier && AWSLambdaUtils.isSupportedAWSType(inputType)) {
|
||||
builder.setHeader(AWSLambdaUtils.AWS_EVENT, true);
|
||||
}
|
||||
|
||||
|
||||
if (messageBuilder == null) {
|
||||
messageBuilder = MessageBuilder.withPayload(payload);
|
||||
//
|
||||
if (structMessage instanceof Map && ((Map<String, Object>) structMessage).containsKey("headers")) {
|
||||
builder.copyHeaders((Map<String, Object>) ((Map<String, Object>) structMessage).get("headers"));
|
||||
}
|
||||
if (awsContext != null) {
|
||||
messageBuilder.setHeader(AWS_CONTEXT, awsContext);
|
||||
}
|
||||
logger.info("Incoming request headers: " + headers);
|
||||
|
||||
return messageBuilder.copyHeaders(headers).build();
|
||||
}
|
||||
|
||||
@SuppressWarnings({ "rawtypes", "unchecked" })
|
||||
private static MessageBuilder createMessageBuilderForPOJOFunction(JsonMapper objectMapper, Map request) {
|
||||
Object body = request.remove("body");
|
||||
try {
|
||||
body = body instanceof String
|
||||
? String.valueOf(body).getBytes(StandardCharsets.UTF_8)
|
||||
: objectMapper.toJson(body);
|
||||
}
|
||||
catch (Exception e) {
|
||||
throw new IllegalStateException(e);
|
||||
}
|
||||
logger.info("Body is " + body);
|
||||
|
||||
MessageBuilder messageBuilder = MessageBuilder.withPayload(body).copyHeaders(request);
|
||||
messageBuilder.setHeader(AWS_API_GATEWAY, true);
|
||||
return messageBuilder;
|
||||
requestMessage = builder.build();
|
||||
return requestMessage;
|
||||
}
|
||||
|
||||
private static byte[] extractPayload(Message<Object> msg, JsonMapper objectMapper) {
|
||||
|
||||
@@ -58,6 +58,9 @@ class AWSTypesMessageConverter extends JsonMessageConverter {
|
||||
if (message.getHeaders().containsKey(AWSLambdaUtils.AWS_API_GATEWAY) && ((boolean) message.getHeaders().get(AWSLambdaUtils.AWS_API_GATEWAY))) {
|
||||
return true;
|
||||
}
|
||||
if (message.getHeaders().containsKey(AWSLambdaUtils.AWS_EVENT) && ((boolean) message.getHeaders().get(AWSLambdaUtils.AWS_EVENT))) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2021-2021 the original author or authors.
|
||||
* Copyright 2021-2022 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -22,10 +22,7 @@ import java.net.SocketException;
|
||||
import java.net.URI;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.text.MessageFormat;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
@@ -45,7 +42,6 @@ import org.springframework.http.MediaType;
|
||||
import org.springframework.http.RequestEntity;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.MessageHeaders;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
@@ -137,9 +133,11 @@ public final class CustomRuntimeEventLoop implements SmartLifecycle {
|
||||
}
|
||||
|
||||
if (response != null) {
|
||||
FunctionInvocationWrapper function = locateFunction(environment, functionCatalog, response.getHeaders().getContentType());
|
||||
Message<byte[]> eventMessage = AWSLambdaUtils.generateMessage(response.getBody().getBytes(StandardCharsets.UTF_8),
|
||||
fromHttp(response.getHeaders()), function.getInputType(), mapper);
|
||||
FunctionInvocationWrapper function = locateFunction(environment, functionCatalog, response.getHeaders());
|
||||
|
||||
Message<byte[]> eventMessage = AWSLambdaUtils
|
||||
.generateMessage(response.getBody().getBytes(StandardCharsets.UTF_8), function.getInputType(), function.isSupplier(), mapper);
|
||||
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("Event message: " + eventMessage);
|
||||
}
|
||||
@@ -206,7 +204,8 @@ public final class CustomRuntimeEventLoop implements SmartLifecycle {
|
||||
return null;
|
||||
}
|
||||
|
||||
private FunctionInvocationWrapper locateFunction(Environment environment, FunctionCatalog functionCatalog, MediaType contentType) {
|
||||
private FunctionInvocationWrapper locateFunction(Environment environment, FunctionCatalog functionCatalog, HttpHeaders httpHeaders) {
|
||||
MediaType contentType = httpHeaders.getContentType();
|
||||
String handlerName = environment.getProperty("DEFAULT_HANDLER");
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("Value of DEFAULT_HANDLER env: " + handlerName);
|
||||
@@ -235,6 +234,15 @@ public final class CustomRuntimeEventLoop implements SmartLifecycle {
|
||||
function = functionCatalog.lookup(handlerName, contentType.toString());
|
||||
}
|
||||
|
||||
if (function == null) {
|
||||
logger.info("Could not determine DEFAULT_HANDLER, _HANDLER or 'spring.cloud.function.definition'");
|
||||
handlerName = httpHeaders.getFirst("spring.cloud.function.definition");
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("Value of 'spring.cloud.function.definition' header: " + handlerName);
|
||||
}
|
||||
function = functionCatalog.lookup(handlerName, contentType.toString());
|
||||
}
|
||||
|
||||
Assert.notNull(function, "Failed to locate function. Tried locating default function, "
|
||||
+ "function by 'DEFAULT_HANDLER', '_HANDLER' env variable as well as'spring.cloud.function.definition'. "
|
||||
+ "Functions available in catalog are: " + functionCatalog.getNames(null));
|
||||
@@ -244,25 +252,6 @@ public final class CustomRuntimeEventLoop implements SmartLifecycle {
|
||||
return function;
|
||||
}
|
||||
|
||||
private MessageHeaders fromHttp(HttpHeaders headers) {
|
||||
Map<String, Object> map = new LinkedHashMap<>();
|
||||
for (String name : headers.keySet()) {
|
||||
Collection<?> values = multi(headers.get(name));
|
||||
name = name.toLowerCase();
|
||||
Object value = values == null ? null
|
||||
: (values.size() == 1 ? values.iterator().next() : values);
|
||||
if (name.toLowerCase().equals(HttpHeaders.CONTENT_TYPE.toLowerCase())) {
|
||||
name = MessageHeaders.CONTENT_TYPE;
|
||||
}
|
||||
map.put(name, value);
|
||||
}
|
||||
return new MessageHeaders(map);
|
||||
}
|
||||
|
||||
private Collection<?> multi(Object value) {
|
||||
return value instanceof Collection ? (Collection<?>) value : Arrays.asList(value);
|
||||
}
|
||||
|
||||
private static String extractVersion() {
|
||||
String path = CustomRuntimeEventLoop.class.getProtectionDomain().getCodeSource().getLocation().toString();
|
||||
int endIndex = path.lastIndexOf('.');
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2019-2021 the original author or authors.
|
||||
* Copyright 2019-2022 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -19,13 +19,10 @@ package org.springframework.cloud.function.adapter.aws;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Calendar;
|
||||
import java.util.Collections;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import com.amazonaws.services.lambda.runtime.Context;
|
||||
@@ -53,7 +50,6 @@ import org.springframework.context.ApplicationContextInitializer;
|
||||
import org.springframework.context.ConfigurableApplicationContext;
|
||||
import org.springframework.core.env.Environment;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.MessageHeaders;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -80,35 +76,11 @@ public class FunctionInvoker implements RequestStreamHandler {
|
||||
this.start();
|
||||
}
|
||||
|
||||
@SuppressWarnings({ "rawtypes", "unchecked" })
|
||||
@SuppressWarnings({ "rawtypes" })
|
||||
@Override
|
||||
public void handleRequest(InputStream input, OutputStream output, Context context) throws IOException {
|
||||
final byte[] payload = StreamUtils.copyToByteArray(input);
|
||||
|
||||
if (logger.isInfoEnabled()) {
|
||||
logger.info("Received: " + new String(payload, StandardCharsets.UTF_8));
|
||||
}
|
||||
|
||||
Object structMessage = this.jsonMapper.fromJson(payload, Object.class);
|
||||
|
||||
boolean isApiGateway = structMessage instanceof Map
|
||||
&& (((Map) structMessage).containsKey("httpMethod") ||
|
||||
(((Map) structMessage).containsKey("routeKey") && ((Map) structMessage).containsKey("version")));
|
||||
|
||||
|
||||
// TODO we should eventually completely delegate to message converter
|
||||
Message requestMessage;
|
||||
if (isApiGateway) {
|
||||
MessageBuilder builder = MessageBuilder.withPayload(payload).setHeader(AWSLambdaUtils.AWS_API_GATEWAY, true);
|
||||
if (structMessage instanceof Map && ((Map) structMessage).containsKey("headers")) {
|
||||
builder.copyHeaders((Map) ((Map) structMessage).get("headers"));
|
||||
}
|
||||
requestMessage = builder.build();
|
||||
}
|
||||
else {
|
||||
requestMessage = AWSLambdaUtils
|
||||
.generateMessage(payload, new MessageHeaders(Collections.emptyMap()), function.getInputType(), this.jsonMapper, context);
|
||||
}
|
||||
Message requestMessage = AWSLambdaUtils
|
||||
.generateMessage(StreamUtils.copyToByteArray(input), this.function.getInputType(), this.function.isSupplier(), jsonMapper);
|
||||
|
||||
Object response = this.function.apply(requestMessage);
|
||||
byte[] responseBytes = this.buildResult(requestMessage, response);
|
||||
|
||||
Reference in New Issue
Block a user