GH-662 Fix support for reactive functions in AWS

This commit also includes other minor fixes around CustomRuntime which was getting in the way of this specific issue
 Added lookup for _HANDLER env variable
 Added few tests (will need more)
 Added support for Iterable for reactive functions

Resolves #662
This commit is contained in:
Oleg Zhurakousky
2021-04-08 15:56:46 +02:00
parent cf58cdc700
commit fc42819357
7 changed files with 134 additions and 45 deletions

View File

@@ -122,6 +122,9 @@ final class AWSLambdaUtils {
}
}
}
else if (request instanceof Iterable) {
messageBuilder = MessageBuilder.withPayload(request);
}
if (messageBuilder == null) {
messageBuilder = MessageBuilder.withPayload(payload);
}

View File

@@ -28,13 +28,9 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.boot.CommandLineRunner;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.cloud.function.context.FunctionCatalog;
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
@@ -52,9 +48,7 @@ import org.springframework.web.client.RestTemplate;
* @since 3.1.1
*
*/
@Configuration
@ConditionalOnProperty("AWS_LAMBDA_RUNTIME_API")
public class CustomRuntimeEventLoop {
final class CustomRuntimeEventLoop {
private static Log logger = LogFactory.getLog(CustomRuntimeEventLoop.class);
@@ -62,10 +56,7 @@ public class CustomRuntimeEventLoop {
private static final String LAMBDA_RUNTIME_URL_TEMPLATE = "http://{0}/{1}/runtime/invocation/next";
private static final String LAMBDA_INVOCATION_URL_TEMPLATE = "http://{0}/{1}/runtime/invocation/{2}/response";
@Bean
@ConditionalOnProperty("AWS_LAMBDA_RUNTIME_API")
public CommandLineRunner backgrounder(ApplicationContext applicationContext) {
return args -> eventLoop(applicationContext);
private CustomRuntimeEventLoop() {
}
@SuppressWarnings("unchecked")
@@ -124,12 +115,16 @@ public class CustomRuntimeEventLoop {
private static FunctionInvocationWrapper locateFunction(FunctionCatalog functionCatalog, MediaType contentType) {
String handlerName = System.getenv("DEFAULT_HANDLER");
FunctionInvocationWrapper function = functionCatalog.lookup(handlerName, contentType.toString());
if (function == null) {
handlerName = System.getenv("_HANDLER");
}
function = functionCatalog.lookup(handlerName, contentType.toString());
if (function == null) {
handlerName = System.getenv("spring.cloud.function.definition");
}
function = functionCatalog.lookup(handlerName, contentType.toString());
Assert.notNull(function, "Failed to locate function. Tried locating default function, "
+ "function by '_HANDLER' env variable as well as'spring.cloud.function.definition'.");
+ "function by 'DEFAULT_HANDLER', '_HANDLER' env variable as well as'spring.cloud.function.definition'.");
if (function != null && logger.isInfoEnabled()) {
logger.info("Located function " + function.getFunctionDefinition());
}

View File

@@ -47,7 +47,8 @@ public class CustomRuntimeInitializer implements ApplicationContextInitializer<G
CommandLineRunner.class, () -> args -> CustomRuntimeEventLoop.eventLoop(context));
}
}
else if (ContextFunctionCatalogInitializer.enabled
else
if (ContextFunctionCatalogInitializer.enabled
&& context.getEnvironment().getProperty("spring.functional.enabled", Boolean.class, false)) {
if (context.getBeanFactory().getBeanNamesForType(DestinationResolver.class, false, false).length == 0) {
context.registerBean(LambdaDestinationResolver.class, () -> new LambdaDestinationResolver());
@@ -60,6 +61,8 @@ public class CustomRuntimeInitializer implements ApplicationContextInitializer<G
private boolean isCustomRuntime() {
String handler = System.getenv("_HANDLER");
if (StringUtils.hasText(handler)) {
handler = handler.split(":")[0];
logger.info("AWS Handler: " + handler);
try {
Class<?> clazz = Thread.currentThread().getContextClassLoader().loadClass(handler);
if (FunctionInvoker.class.isAssignableFrom(clazz) || AbstractSpringFunctionAdapterInitializer.class.isAssignableFrom(clazz)) {

View File

@@ -19,21 +19,18 @@ package org.springframework.cloud.function.adapter.aws;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Calendar;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestStreamHandler;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.datatype.joda.JodaModule;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import org.springframework.boot.SpringApplication;
import org.springframework.cloud.function.context.FunctionCatalog;
@@ -43,6 +40,7 @@ import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.core.env.Environment;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.StreamUtils;
import org.springframework.util.StringUtils;
@@ -67,33 +65,70 @@ public class FunctionInvoker implements RequestStreamHandler {
this.start();
}
@SuppressWarnings({ "unchecked", "rawtypes" })
@SuppressWarnings("rawtypes")
@Override
public void handleRequest(InputStream input, OutputStream output, Context context) throws IOException {
final byte[] payload = StreamUtils.copyToByteArray(input);
if (logger.isInfoEnabled()) {
logger.info("Received: " + new String(payload, StandardCharsets.UTF_8));
}
Message requestMessage = AWSLambdaUtils
.generateMessage(payload, new MessageHeaders(Collections.emptyMap()), function.getInputType(), this.objectMapper, context);
Message<byte[]> responseMessage = (Message<byte[]>) this.function.apply(requestMessage);
byte[] responseBytes = AWSLambdaUtils.generateOutput(requestMessage, responseMessage, this.objectMapper);
Object response = this.function.apply(requestMessage);
byte[] responseBytes = this.buildResult(requestMessage, response);
StreamUtils.copy(responseBytes, output);
}
@SuppressWarnings("unchecked")
private byte[] buildResult(Message<?> requestMessage, Object output) throws IOException {
Message<byte[]> responseMessage;
if (output instanceof Publisher<?>) {
List<Object> result = new ArrayList<>();
for (Object value : Flux.from((Publisher<?>) output).toIterable()) {
if (logger.isInfoEnabled()) {
logger.info("Response value: " + value);
}
result.add(value);
}
if (result.size() > 1) {
output = result;
}
else {
output = result.get(0);
}
if (logger.isInfoEnabled()) {
logger.info("OUTPUT: " + output + " - " + output.getClass().getName());
}
byte[] payload = this.objectMapper.writeValueAsBytes(output);
responseMessage = MessageBuilder.withPayload(payload).build();
}
else {
responseMessage = (Message<byte[]>) output;
}
return AWSLambdaUtils.generateOutput(requestMessage, responseMessage, this.objectMapper);
}
private void start() {
ConfigurableApplicationContext context = SpringApplication.run(FunctionClassUtils.getStartClass());
Environment environment = context.getEnvironment();
String functionName = environment.getProperty("spring.cloud.function.definition");
FunctionCatalog functionCatalog = context.getBean(FunctionCatalog.class);
this.objectMapper = context.getBean(ObjectMapper.class);
//this.configureObjectMapper();
if (logger.isInfoEnabled()) {
logger.info("Locating function: '" + functionName + "'");
}
this.function = functionCatalog.lookup(functionName, "application/json");
if (this.function.isOutputTypePublisher()) {
this.function.setSkipOutputConversion(true);
}
Assert.notNull(this.function, "Failed to lookup function " + functionName);
if (!StringUtils.hasText(functionName)) {
@@ -104,20 +139,4 @@ public class FunctionInvoker implements RequestStreamHandler {
logger.info("Located function: '" + functionName + "'");
}
}
private void configureObjectMapper() {
SimpleModule module = new SimpleModule();
module.addDeserializer(Date.class, new JsonDeserializer<Date>() {
@Override
public Date deserialize(JsonParser jsonParser, DeserializationContext deserializationContext)
throws IOException {
Calendar calendar = Calendar.getInstance();
calendar.setTimeInMillis(jsonParser.getValueAsLong());
return calendar.getTime();
}
});
this.objectMapper.registerModule(module);
this.objectMapper.registerModule(new JodaModule());
this.objectMapper.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true);
}
}

View File

@@ -32,6 +32,7 @@ import com.amazonaws.services.lambda.runtime.events.SNSEvent;
import com.amazonaws.services.lambda.runtime.events.SQSEvent;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.context.annotation.Bean;
@@ -51,6 +52,8 @@ public class FunctionInvokerTests {
ObjectMapper mapper = new ObjectMapper();
String jsonCollection = "[\"Ricky\",\"Julien\",\"Bubbles\"]";
String sampleLBEvent = "{" +
" \"requestContext\": {" +
" \"elb\": {" +
@@ -360,6 +363,20 @@ public class FunctionInvokerTests {
" \"isBase64Encoded\": false\n" +
"}";
@Test
public void testCollection() throws Exception {
System.setProperty("MAIN_CLASS", SampleConfiguration.class.getName());
System.setProperty("spring.cloud.function.definition", "echoStringReactive");
FunctionInvoker invoker = new FunctionInvoker();
InputStream targetStream = new ByteArrayInputStream(this.jsonCollection.getBytes());
ByteArrayOutputStream output = new ByteArrayOutputStream();
invoker.handleRequest(targetStream, output, null);
String result = new String(output.toByteArray(), StandardCharsets.UTF_8);
assertThat(result).isEqualTo(this.jsonCollection);
}
@Test
public void testKinesisStringEvent() throws Exception {
System.setProperty("MAIN_CLASS", KinesisConfiguration.class.getName());
@@ -689,6 +706,20 @@ public class FunctionInvokerTests {
assertThat(result.get("body")).isEqualTo("\"OK\"");
}
@EnableAutoConfiguration
@Configuration
public static class SampleConfiguration {
@Bean
public Function<String, String> echoString() {
return v -> v;
}
@Bean
public Function<Flux<String>, Flux<String>> echoStringReactive() {
return v -> v;
}
}
@EnableAutoConfiguration
@Configuration
public static class KinesisConfiguration {