GH-751 Ensure RoutingFunction can be applied when function input type is AWS type

This commit is contained in:
Oleg Zhurakousky
2021-10-28 12:38:27 +02:00
parent 037f1b8bfe
commit eeb5448a7d
9 changed files with 292 additions and 59 deletions

View File

@@ -0,0 +1,37 @@
/*
* Copyright 2021-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.function.adapter.aws;
import org.springframework.cloud.function.json.JsonMapper;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.converter.MessageConverter;
/**
*
* @author Oleg Zhurakousky
* @since 3.2
*
*/
@Configuration(proxyBeanMethods = false)
public class AWSCompanionAutoConfiguration {
@Bean
public MessageConverter awsTypesConverter(JsonMapper jsonMapper) {
return new AWSTypesMessageConverter(jsonMapper);
}
}

View File

@@ -17,12 +17,9 @@
package org.springframework.cloud.function.adapter.aws;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.Calendar;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
@@ -32,17 +29,11 @@ import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayV2HTTPEvent;
import com.amazonaws.services.lambda.runtime.serialization.PojoSerializer;
import com.amazonaws.services.lambda.runtime.serialization.events.LambdaEventSerializers;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.datatype.joda.JodaModule;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
import org.springframework.cloud.function.json.JsonMapper;
import org.springframework.http.HttpStatus;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
@@ -58,18 +49,18 @@ final class AWSLambdaUtils {
private static Log logger = LogFactory.getLog(AWSLambdaUtils.class);
private static final String AWS_API_GATEWAY = "aws-api-gateway";
static final String AWS_API_GATEWAY = "aws-api-gateway";
private AWSLambdaUtils() {
}
public static Message<byte[]> generateMessage(byte[] payload, MessageHeaders headers,
Type inputType, ObjectMapper objectMapper) {
Type inputType, JsonMapper objectMapper) {
return generateMessage(payload, headers, inputType, objectMapper, null);
}
private static boolean isSupportedAWSType(Type inputType) {
static boolean isSupportedAWSType(Type inputType) {
String typeName = inputType.getTypeName();
return typeName.equals("com.amazonaws.services.lambda.runtime.events.APIGatewayV2HTTPEvent")
|| typeName.equals("com.amazonaws.services.lambda.runtime.events.S3Event")
@@ -81,7 +72,7 @@ final class AWSLambdaUtils {
@SuppressWarnings({ "unchecked", "rawtypes" })
public static Message<byte[]> generateMessage(byte[] payload, MessageHeaders headers,
Type inputType, ObjectMapper objectMapper, @Nullable Context awsContext) {
Type inputType, JsonMapper objectMapper, @Nullable Context awsContext) {
if (logger.isInfoEnabled()) {
logger.info("Incoming JSON Event: " + new String(payload));
@@ -102,12 +93,9 @@ final class AWSLambdaUtils {
}
}
else {
if (!objectMapper.isEnabled(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES)) {
configureObjectMapper(objectMapper);
}
Object request;
try {
request = objectMapper.readValue(payload, Object.class);
request = objectMapper.fromJson(payload, Object.class);
}
catch (Exception e) {
throw new IllegalStateException(e);
@@ -154,12 +142,12 @@ final class AWSLambdaUtils {
}
@SuppressWarnings({ "rawtypes", "unchecked" })
private static MessageBuilder createMessageBuilderForPOJOFunction(ObjectMapper objectMapper, Map request) {
private static MessageBuilder createMessageBuilderForPOJOFunction(JsonMapper objectMapper, Map request) {
Object body = request.remove("body");
try {
body = body instanceof String
? String.valueOf(body).getBytes(StandardCharsets.UTF_8)
: objectMapper.writeValueAsBytes(body);
: objectMapper.toJson(body);
}
catch (Exception e) {
throw new IllegalStateException(e);
@@ -173,7 +161,7 @@ final class AWSLambdaUtils {
@SuppressWarnings({ "rawtypes", "unchecked" })
public static byte[] generateOutput(Message requestMessage, Message<byte[]> responseMessage,
ObjectMapper objectMapper, Type functionOutputType) {
JsonMapper objectMapper, Type functionOutputType) {
Class<?> outputClass = FunctionTypeUtils.getRawType(functionOutputType);
if (outputClass != null) {
@@ -184,9 +172,6 @@ final class AWSLambdaUtils {
}
}
if (!objectMapper.isEnabled(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES)) {
configureObjectMapper(objectMapper);
}
byte[] responseBytes = responseMessage == null ? "\"OK\"".getBytes() : responseMessage.getPayload();
if (requestMessage.getHeaders().containsKey(AWS_API_GATEWAY) && ((boolean) requestMessage.getHeaders().get(AWS_API_GATEWAY))) {
Map<String, Object> response = new HashMap<String, Object>();
@@ -218,7 +203,7 @@ final class AWSLambdaUtils {
}
try {
responseBytes = objectMapper.writeValueAsBytes(response);
responseBytes = objectMapper.toJson(response);
}
catch (Exception e) {
throw new IllegalStateException("Failed to serialize AWS Lambda output", e);
@@ -227,23 +212,6 @@ final class AWSLambdaUtils {
return responseBytes;
}
private static void configureObjectMapper(ObjectMapper objectMapper) {
SimpleModule module = new SimpleModule();
module.addDeserializer(Date.class, new JsonDeserializer<Date>() {
@Override
public Date deserialize(JsonParser jsonParser, DeserializationContext deserializationContext)
throws IOException {
Calendar calendar = Calendar.getInstance();
calendar.setTimeInMillis(jsonParser.getValueAsLong());
return calendar.getTime();
}
});
objectMapper.registerModule(module);
objectMapper.registerModule(new JodaModule());
objectMapper.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true);
}
private static boolean isRequestKinesis(Message<Object> requestMessage) {
return requestMessage.getHeaders().containsKey("Records");
}

View File

@@ -0,0 +1,137 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.function.adapter.aws;
import java.io.ByteArrayInputStream;
import java.util.Map;
import com.amazonaws.services.lambda.runtime.serialization.PojoSerializer;
import com.amazonaws.services.lambda.runtime.serialization.events.LambdaEventSerializers;
import org.springframework.cloud.function.cloudevent.CloudEventMessageUtils;
import org.springframework.cloud.function.context.config.JsonMessageConverter;
import org.springframework.cloud.function.json.JsonMapper;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.util.MimeType;
/**
* Implementation of {@link MessageConverter} which uses Jackson or Gson libraries to do the
* actual conversion via {@link JsonMapper} instance.
*
* @author Oleg Zhurakousky
*
* @since 3.2
*/
class AWSTypesMessageConverter extends JsonMessageConverter {
private final JsonMapper jsonMapper;
AWSTypesMessageConverter(JsonMapper jsonMapper) {
this(jsonMapper, new MimeType("application", "json"), new MimeType(CloudEventMessageUtils.APPLICATION_CLOUDEVENTS.getType(),
CloudEventMessageUtils.APPLICATION_CLOUDEVENTS.getSubtype() + "+json"));
}
AWSTypesMessageConverter(JsonMapper jsonMapper, MimeType... supportedMimeTypes) {
super(jsonMapper, supportedMimeTypes);
this.jsonMapper = jsonMapper;
}
@Override
protected boolean canConvertFrom(Message<?> message, @Nullable Class<?> targetClass) {
//if (targetClass.getPackage().getName().startsWith("com.amazonaws.services.lambda.runtime.events")) {
if (message.getHeaders().containsKey(AWSLambdaUtils.AWS_API_GATEWAY) && ((boolean) message.getHeaders().get(AWSLambdaUtils.AWS_API_GATEWAY))) {
return true;
}
return false;
}
@Override
protected Object convertFromInternal(Message<?> message, Class<?> targetClass, @Nullable Object conversionHint) {
if (message.getPayload().getClass().isAssignableFrom(targetClass)) {
return message.getPayload();
}
if (targetClass.getPackage().getName().startsWith("com.amazonaws.services.lambda.runtime.events")) {
PojoSerializer<?> serializer = LambdaEventSerializers.serializerFor(targetClass, Thread.currentThread().getContextClassLoader());
Object event = serializer.fromJson(new ByteArrayInputStream((byte[]) message.getPayload()));
return event;
}
else {
Map<String, String> structMessage = this.jsonMapper.fromJson(message.getPayload(), Map.class);
if (targetClass.isAssignableFrom(Map.class)) {
return structMessage;
}
else {
Object body = structMessage.get("body");
Object convertedResult = this.jsonMapper.fromJson(body, targetClass);
return convertedResult;
}
}
}
@Override
protected boolean canConvertTo(Object payload, @Nullable MessageHeaders headers) {
if (!supportsMimeType(headers)) {
return false;
}
return true;
}
@Override
protected Object convertToInternal(Object payload, @Nullable MessageHeaders headers,
@Nullable Object conversionHint) {
if (headers.containsKey(AWSLambdaUtils.AWS_API_GATEWAY) && ((boolean) headers.get(AWSLambdaUtils.AWS_API_GATEWAY))) {
// AtomicReference<MessageHeaders> headersRef = new AtomicReference<>();
// int statusCode = HttpStatus.OK.value();
// if (responseMessage != null) {
// headers.set(responseMessage.getHeaders());
// statusCode = headers.get().containsKey("statusCode")
// ? (int) headers.get().get("statusCode")
// : HttpStatus.OK.value();
// }
//
// response.put("statusCode", statusCode);
// if (isRequestKinesis(requestMessage)) {
// HttpStatus httpStatus = HttpStatus.valueOf(statusCode);
// response.put("statusDescription", httpStatus.toString());
// }
//
// String body = responseMessage == null
// ? "\"OK\"" : new String(responseMessage.getPayload(), StandardCharsets.UTF_8).replaceAll("\\\"", "");
// response.put("body", body);
//
// if (responseMessage != null) {
// Map<String, String> responseHeaders = new HashMap<>();
// headers.get().keySet().forEach(key -> responseHeaders.put(key, headers.get().get(key).toString()));
// response.put("headers", responseHeaders);
// }
//
// try {
// responseBytes = objectMapper.toJson(response);
// }
// catch (Exception e) {
// throw new IllegalStateException("Failed to serialize AWS Lambda output", e);
// }
}
return jsonMapper.toJson(payload);
}
}

View File

@@ -27,12 +27,12 @@ import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.function.context.FunctionCatalog;
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper;
import org.springframework.cloud.function.json.JsonMapper;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.SmartLifecycle;
import org.springframework.core.env.Environment;
@@ -95,7 +95,7 @@ public final class CustomRuntimeEventLoop implements SmartLifecycle {
RequestEntity<Void> requestEntity = RequestEntity.get(URI.create(eventUri)).build();
FunctionCatalog functionCatalog = context.getBean(FunctionCatalog.class);
RestTemplate rest = new RestTemplate();
ObjectMapper mapper = context.getBean(ObjectMapper.class);
JsonMapper mapper = context.getBean(JsonMapper.class);
logger.info("Entering event loop");
while (this.isRunning()) {

View File

@@ -21,14 +21,22 @@ import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestStreamHandler;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.datatype.joda.JodaModule;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
@@ -39,6 +47,8 @@ import org.springframework.cloud.function.context.FunctionCatalog;
import org.springframework.cloud.function.context.FunctionalSpringApplication;
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper;
import org.springframework.cloud.function.context.config.RoutingFunction;
import org.springframework.cloud.function.json.JacksonMapper;
import org.springframework.cloud.function.json.JsonMapper;
import org.springframework.cloud.function.utils.FunctionClassUtils;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.ConfigurableApplicationContext;
@@ -64,7 +74,7 @@ public class FunctionInvoker implements RequestStreamHandler {
private static Log logger = LogFactory.getLog(FunctionInvoker.class);
private ObjectMapper objectMapper;
private JsonMapper jsonMapper;
private FunctionInvocationWrapper function;
@@ -81,8 +91,18 @@ public class FunctionInvoker implements RequestStreamHandler {
logger.info("Received: " + new String(payload, StandardCharsets.UTF_8));
}
Message requestMessage = AWSLambdaUtils
.generateMessage(payload, new MessageHeaders(Collections.emptyMap()), function.getInputType(), this.objectMapper, context);
Object structMessage = this.jsonMapper.fromJson(payload, Object.class);
boolean isApiGateway = structMessage instanceof Map
&& (((Map) structMessage).containsKey("httpMethod") ||
(((Map) structMessage).containsKey("routeKey") && ((Map) structMessage).containsKey("version")));
// TODO we should eventually completely delegate to message converter
//Message requestMessage = MessageBuilder.withPayload(payload).setHeader(AWSLambdaUtils.AWS_API_GATEWAY, true).build();
Message requestMessage = isApiGateway
? MessageBuilder.withPayload(payload).setHeader(AWSLambdaUtils.AWS_API_GATEWAY, true).build()
: AWSLambdaUtils.generateMessage(payload, new MessageHeaders(Collections.emptyMap()), function.getInputType(), this.jsonMapper, context);
try {
Object response = this.function.apply(requestMessage);
@@ -99,7 +119,7 @@ public class FunctionInvoker implements RequestStreamHandler {
APIGatewayProxyResponseEvent event = new APIGatewayProxyResponseEvent();
event.setStatusCode(HttpStatus.EXPECTATION_FAILED.value());
event.setBody(exception.getMessage());
return this.objectMapper.writeValueAsBytes(event);
return this.jsonMapper.toJson(event);
}
@SuppressWarnings("unchecked")
@@ -124,26 +144,45 @@ public class FunctionInvoker implements RequestStreamHandler {
logger.info("OUTPUT: " + output + " - " + output.getClass().getName());
}
byte[] payload = this.objectMapper.writeValueAsBytes(output);
byte[] payload = this.jsonMapper.toJson(output);
responseMessage = MessageBuilder.withPayload(payload).build();
}
else {
responseMessage = (Message<byte[]>) output;
}
return AWSLambdaUtils.generateOutput(requestMessage, responseMessage, this.objectMapper, function.getOutputType());
return AWSLambdaUtils.generateOutput(requestMessage, responseMessage, this.jsonMapper, function.getOutputType());
}
private void start() {
Class<?> startClass = FunctionClassUtils.getStartClass();
String[] properties = new String[] {"--spring.cloud.function.web.export.enabled=false", "--spring.main.web-application-type=none"};
ConfigurableApplicationContext context = ApplicationContextInitializer.class.isAssignableFrom(startClass)
? FunctionalSpringApplication.run(startClass, properties)
: SpringApplication.run(FunctionClassUtils.getStartClass(), properties);
? FunctionalSpringApplication.run(new Class[] {startClass, AWSCompanionAutoConfiguration.class}, properties)
: SpringApplication.run(new Class[] {startClass, AWSCompanionAutoConfiguration.class}, properties);
Environment environment = context.getEnvironment();
String functionName = environment.getProperty("spring.cloud.function.definition");
FunctionCatalog functionCatalog = context.getBean(FunctionCatalog.class);
this.objectMapper = context.getBean(ObjectMapper.class);
this.jsonMapper = context.getBean(JsonMapper.class);
if (this.jsonMapper instanceof JacksonMapper) {
((JacksonMapper) this.jsonMapper).configureObjectMapper(objectMapper -> {
if (!objectMapper.isEnabled(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES)) {
SimpleModule module = new SimpleModule();
module.addDeserializer(Date.class, new JsonDeserializer<Date>() {
@Override
public Date deserialize(JsonParser jsonParser, DeserializationContext deserializationContext)
throws IOException {
Calendar calendar = Calendar.getInstance();
calendar.setTimeInMillis(jsonParser.getValueAsLong());
return calendar.getTime();
}
});
objectMapper.registerModule(module);
objectMapper.registerModule(new JodaModule());
objectMapper.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true);
}
});
}
if (logger.isInfoEnabled()) {
logger.info("Locating function: '" + functionName + "'");

View File

@@ -698,14 +698,24 @@ public class FunctionInvokerTests {
InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
ByteArrayOutputStream output = new ByteArrayOutputStream();
invoker.handleRequest(targetStream, output, null);
ObjectMapper mapper = new ObjectMapper();
Map result = mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("HELLO");
System.clearProperty("spring.cloud.function.definition");
System.setProperty("spring.cloud.function.routing-expression", "'uppercase'");
invoker = new FunctionInvoker();
targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
output = new ByteArrayOutputStream();
invoker.handleRequest(targetStream, output, null);
result = this.mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("HELLO");
}
@SuppressWarnings("rawtypes")
@Test
public void testApiGatewayMapEventBody() throws Exception {
public void testApiGatewayPojoEventBody() throws Exception {
System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName());
System.setProperty("spring.cloud.function.definition", "uppercasePojo");
FunctionInvoker invoker = new FunctionInvoker();
@@ -716,6 +726,16 @@ public class FunctionInvokerTests {
Map result = mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("JIM LAHEY");
System.clearProperty("spring.cloud.function.definition");
System.setProperty("spring.cloud.function.routing-expression", "'uppercasePojo'");
invoker = new FunctionInvoker();
targetStream = new ByteArrayInputStream(this.apiGatewayEventWithStructuredBody.getBytes());
output = new ByteArrayOutputStream();
invoker.handleRequest(targetStream, output, null);
result = this.mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("JIM LAHEY");
}
@SuppressWarnings("rawtypes")
@@ -732,6 +752,16 @@ public class FunctionInvokerTests {
Map result = mapper.readValue(output.toByteArray(), Map.class);
System.out.println(result);
assertThat(result.get("body")).isEqualTo("hello");
System.clearProperty("spring.cloud.function.definition");
System.setProperty("spring.cloud.function.routing-expression", "'inputApiEvent'");
invoker = new FunctionInvoker();
targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
output = new ByteArrayOutputStream();
invoker.handleRequest(targetStream, output, null);
result = this.mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("hello");
}
@SuppressWarnings("rawtypes")
@@ -748,6 +778,16 @@ public class FunctionInvokerTests {
Map result = mapper.readValue(output.toByteArray(), Map.class);
System.out.println(result);
assertThat(result.get("body")).isEqualTo("Hello from Lambda");
System.clearProperty("spring.cloud.function.definition");
System.setProperty("spring.cloud.function.routing-expression", "'inputApiV2Event'");
invoker = new FunctionInvoker();
targetStream = new ByteArrayInputStream(this.apiGatewayV2Event.getBytes());
output = new ByteArrayOutputStream();
invoker.handleRequest(targetStream, output, null);
result = this.mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("Hello from Lambda");
}
@SuppressWarnings("rawtypes")
@@ -1128,7 +1168,9 @@ public class FunctionInvokerTests {
@Bean
public Function<Person, String> uppercasePojo() {
return v -> v.getName().toUpperCase();
return v -> {
return v.getName().toUpperCase();
};
}
@Bean