Reworked AWSUtils and related classes to delegate type conversion to AWSTypesMessageConverter

Resolves #889

polish
This commit is contained in:
Oleg Zhurakousky
2022-08-03 15:22:21 +02:00
parent fa1a7a98a3
commit 8bcdeb5cc2
4 changed files with 46 additions and 153 deletions

View File

@@ -16,26 +16,18 @@
package org.springframework.cloud.function.adapter.aws;
import java.io.ByteArrayInputStream;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayV2HTTPEvent;
import com.amazonaws.services.lambda.runtime.serialization.PojoSerializer;
import com.amazonaws.services.lambda.runtime.serialization.events.LambdaEventSerializers;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
import org.springframework.cloud.function.json.JsonMapper;
import org.springframework.http.HttpStatus;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.support.MessageBuilder;
@@ -51,18 +43,18 @@ final class AWSLambdaUtils {
static final String AWS_API_GATEWAY = "aws-api-gateway";
static final String AWS_EVENT = "aws-event";
public static final String AWS_CONTEXT = "aws-context";
private AWSLambdaUtils() {
}
public static Message<byte[]> generateMessage(byte[] payload, MessageHeaders headers,
Type inputType, JsonMapper objectMapper) {
return generateMessage(payload, headers, inputType, objectMapper, null);
}
static boolean isSupportedAWSType(Type inputType) {
if (FunctionTypeUtils.isMessage(inputType)) {
inputType = FunctionTypeUtils.getImmediateGenericType(inputType, 0);
}
String typeName = inputType.getTypeName();
return typeName.equals("com.amazonaws.services.lambda.runtime.events.APIGatewayV2HTTPEvent")
|| typeName.equals("com.amazonaws.services.lambda.runtime.events.S3Event")
@@ -74,93 +66,30 @@ final class AWSLambdaUtils {
}
@SuppressWarnings({ "unchecked", "rawtypes" })
public static Message<byte[]> generateMessage(byte[] payload, MessageHeaders headers,
Type inputType, JsonMapper objectMapper, @Nullable Context awsContext) {
public static Message<byte[]> generateMessage(byte[] payload, Type inputType, boolean isSupplier, JsonMapper jsonMapper) {
if (logger.isInfoEnabled()) {
logger.info("Incoming JSON Event: " + new String(payload));
logger.info("Received: " + new String(payload, StandardCharsets.UTF_8));
}
if (FunctionTypeUtils.isMessage(inputType)) {
inputType = FunctionTypeUtils.getImmediateGenericType(inputType, 0);
}
Object structMessage = jsonMapper.fromJson(payload, Object.class);
boolean isApiGateway = structMessage instanceof Map
&& (((Map<String, Object>) structMessage).containsKey("httpMethod") ||
(((Map<String, Object>) structMessage).containsKey("routeKey") && ((Map) structMessage).containsKey("version")));
MessageBuilder messageBuilder = null;
if (inputType != null && isSupportedAWSType(inputType)) {
PojoSerializer<?> serializer = LambdaEventSerializers.serializerFor(FunctionTypeUtils.getRawType(inputType), Thread.currentThread().getContextClassLoader());
Object event = serializer.fromJson(new ByteArrayInputStream(payload));
messageBuilder = MessageBuilder.withPayload(event);
if (event instanceof APIGatewayProxyRequestEvent || event instanceof APIGatewayV2HTTPEvent) {
messageBuilder.setHeader(AWS_API_GATEWAY, true);
logger.info("Incoming request is API Gateway");
}
Message<byte[]> requestMessage;
MessageBuilder<byte[]> builder = MessageBuilder.withPayload(payload);
if (isApiGateway) {
builder.setHeader(AWSLambdaUtils.AWS_API_GATEWAY, true);
}
else {
Object request;
try {
request = objectMapper.fromJson(payload, Object.class);
}
catch (Exception e) {
throw new IllegalStateException(e);
}
if (request instanceof Map) {
logger.info("Incoming MAP: " + request);
if (((Map) request).containsKey("httpMethod")) { //API Gateway
logger.info("Incoming request is API Gateway");
boolean mapInputType = (inputType instanceof ParameterizedType && ((Class<?>) ((ParameterizedType) inputType).getRawType()).isAssignableFrom(Map.class));
if (mapInputType) {
messageBuilder = MessageBuilder.withPayload(request).setHeader("httpMethod", ((Map) request).get("httpMethod"));
messageBuilder.setHeader(AWS_API_GATEWAY, true);
}
else {
messageBuilder = createMessageBuilderForPOJOFunction(objectMapper, (Map) request);
}
}
else if ((((Map) request).containsKey("routeKey") && ((Map) request).containsKey("version"))) {
logger.info("Incoming request is API Gateway v2.0");
messageBuilder = createMessageBuilderForPOJOFunction(objectMapper, (Map) request);
}
Object providedHeaders = ((Map) request).get("headers");
if (providedHeaders != null && providedHeaders instanceof Map) {
messageBuilder = MessageBuilder.withPayload(request);
messageBuilder.removeHeader("headers");
messageBuilder.copyHeaders((Map<String, Object>) providedHeaders);
}
}
else if (request instanceof Iterable) {
messageBuilder = MessageBuilder.withPayload(request);
}
if (!isSupplier && AWSLambdaUtils.isSupportedAWSType(inputType)) {
builder.setHeader(AWSLambdaUtils.AWS_EVENT, true);
}
if (messageBuilder == null) {
messageBuilder = MessageBuilder.withPayload(payload);
//
if (structMessage instanceof Map && ((Map<String, Object>) structMessage).containsKey("headers")) {
builder.copyHeaders((Map<String, Object>) ((Map<String, Object>) structMessage).get("headers"));
}
if (awsContext != null) {
messageBuilder.setHeader(AWS_CONTEXT, awsContext);
}
logger.info("Incoming request headers: " + headers);
return messageBuilder.copyHeaders(headers).build();
}
@SuppressWarnings({ "rawtypes", "unchecked" })
private static MessageBuilder createMessageBuilderForPOJOFunction(JsonMapper objectMapper, Map request) {
Object body = request.remove("body");
try {
body = body instanceof String
? String.valueOf(body).getBytes(StandardCharsets.UTF_8)
: objectMapper.toJson(body);
}
catch (Exception e) {
throw new IllegalStateException(e);
}
logger.info("Body is " + body);
MessageBuilder messageBuilder = MessageBuilder.withPayload(body).copyHeaders(request);
messageBuilder.setHeader(AWS_API_GATEWAY, true);
return messageBuilder;
requestMessage = builder.build();
return requestMessage;
}
private static byte[] extractPayload(Message<Object> msg, JsonMapper objectMapper) {

View File

@@ -58,6 +58,9 @@ class AWSTypesMessageConverter extends JsonMessageConverter {
if (message.getHeaders().containsKey(AWSLambdaUtils.AWS_API_GATEWAY) && ((boolean) message.getHeaders().get(AWSLambdaUtils.AWS_API_GATEWAY))) {
return true;
}
if (message.getHeaders().containsKey(AWSLambdaUtils.AWS_EVENT) && ((boolean) message.getHeaders().get(AWSLambdaUtils.AWS_EVENT))) {
return true;
}
return false;
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2021-2021 the original author or authors.
* Copyright 2021-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -22,10 +22,7 @@ import java.net.SocketException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -45,7 +42,6 @@ import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.util.Assert;
import org.springframework.web.client.RestTemplate;
@@ -137,9 +133,11 @@ public final class CustomRuntimeEventLoop implements SmartLifecycle {
}
if (response != null) {
FunctionInvocationWrapper function = locateFunction(environment, functionCatalog, response.getHeaders().getContentType());
Message<byte[]> eventMessage = AWSLambdaUtils.generateMessage(response.getBody().getBytes(StandardCharsets.UTF_8),
fromHttp(response.getHeaders()), function.getInputType(), mapper);
FunctionInvocationWrapper function = locateFunction(environment, functionCatalog, response.getHeaders());
Message<byte[]> eventMessage = AWSLambdaUtils
.generateMessage(response.getBody().getBytes(StandardCharsets.UTF_8), function.getInputType(), function.isSupplier(), mapper);
if (logger.isDebugEnabled()) {
logger.debug("Event message: " + eventMessage);
}
@@ -206,7 +204,8 @@ public final class CustomRuntimeEventLoop implements SmartLifecycle {
return null;
}
private FunctionInvocationWrapper locateFunction(Environment environment, FunctionCatalog functionCatalog, MediaType contentType) {
private FunctionInvocationWrapper locateFunction(Environment environment, FunctionCatalog functionCatalog, HttpHeaders httpHeaders) {
MediaType contentType = httpHeaders.getContentType();
String handlerName = environment.getProperty("DEFAULT_HANDLER");
if (logger.isDebugEnabled()) {
logger.debug("Value of DEFAULT_HANDLER env: " + handlerName);
@@ -235,6 +234,15 @@ public final class CustomRuntimeEventLoop implements SmartLifecycle {
function = functionCatalog.lookup(handlerName, contentType.toString());
}
if (function == null) {
logger.info("Could not determine DEFAULT_HANDLER, _HANDLER or 'spring.cloud.function.definition'");
handlerName = httpHeaders.getFirst("spring.cloud.function.definition");
if (logger.isDebugEnabled()) {
logger.debug("Value of 'spring.cloud.function.definition' header: " + handlerName);
}
function = functionCatalog.lookup(handlerName, contentType.toString());
}
Assert.notNull(function, "Failed to locate function. Tried locating default function, "
+ "function by 'DEFAULT_HANDLER', '_HANDLER' env variable as well as'spring.cloud.function.definition'. "
+ "Functions available in catalog are: " + functionCatalog.getNames(null));
@@ -244,25 +252,6 @@ public final class CustomRuntimeEventLoop implements SmartLifecycle {
return function;
}
private MessageHeaders fromHttp(HttpHeaders headers) {
Map<String, Object> map = new LinkedHashMap<>();
for (String name : headers.keySet()) {
Collection<?> values = multi(headers.get(name));
name = name.toLowerCase();
Object value = values == null ? null
: (values.size() == 1 ? values.iterator().next() : values);
if (name.toLowerCase().equals(HttpHeaders.CONTENT_TYPE.toLowerCase())) {
name = MessageHeaders.CONTENT_TYPE;
}
map.put(name, value);
}
return new MessageHeaders(map);
}
private Collection<?> multi(Object value) {
return value instanceof Collection ? (Collection<?>) value : Arrays.asList(value);
}
private static String extractVersion() {
String path = CustomRuntimeEventLoop.class.getProtectionDomain().getCodeSource().getLocation().toString();
int endIndex = path.lastIndexOf('.');

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2019-2021 the original author or authors.
* Copyright 2019-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,13 +19,10 @@ package org.springframework.cloud.function.adapter.aws;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.amazonaws.services.lambda.runtime.Context;
@@ -53,7 +50,6 @@ import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.core.env.Environment;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@@ -80,35 +76,11 @@ public class FunctionInvoker implements RequestStreamHandler {
this.start();
}
@SuppressWarnings({ "rawtypes", "unchecked" })
@SuppressWarnings({ "rawtypes" })
@Override
public void handleRequest(InputStream input, OutputStream output, Context context) throws IOException {
final byte[] payload = StreamUtils.copyToByteArray(input);
if (logger.isInfoEnabled()) {
logger.info("Received: " + new String(payload, StandardCharsets.UTF_8));
}
Object structMessage = this.jsonMapper.fromJson(payload, Object.class);
boolean isApiGateway = structMessage instanceof Map
&& (((Map) structMessage).containsKey("httpMethod") ||
(((Map) structMessage).containsKey("routeKey") && ((Map) structMessage).containsKey("version")));
// TODO we should eventually completely delegate to message converter
Message requestMessage;
if (isApiGateway) {
MessageBuilder builder = MessageBuilder.withPayload(payload).setHeader(AWSLambdaUtils.AWS_API_GATEWAY, true);
if (structMessage instanceof Map && ((Map) structMessage).containsKey("headers")) {
builder.copyHeaders((Map) ((Map) structMessage).get("headers"));
}
requestMessage = builder.build();
}
else {
requestMessage = AWSLambdaUtils
.generateMessage(payload, new MessageHeaders(Collections.emptyMap()), function.getInputType(), this.jsonMapper, context);
}
Message requestMessage = AWSLambdaUtils
.generateMessage(StreamUtils.copyToByteArray(input), this.function.getInputType(), this.function.isSupplier(), jsonMapper);
Object response = this.function.apply(requestMessage);
byte[] responseBytes = this.buildResult(requestMessage, response);