GH-662 Fix support for reactive functions in AWS
This commit also includes other minor fixes around CustomRuntime which was getting in the way of this specific issue Added lookup for _HANDLER env variable Added few tests (will need more) Added support for Iterable for reactive functions Resolves #662
This commit is contained in:
@@ -122,6 +122,9 @@ final class AWSLambdaUtils {
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (request instanceof Iterable) {
|
||||
messageBuilder = MessageBuilder.withPayload(request);
|
||||
}
|
||||
if (messageBuilder == null) {
|
||||
messageBuilder = MessageBuilder.withPayload(payload);
|
||||
}
|
||||
|
||||
@@ -28,13 +28,9 @@ import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import org.springframework.boot.CommandLineRunner;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.cloud.function.context.FunctionCatalog;
|
||||
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.RequestEntity;
|
||||
@@ -52,9 +48,7 @@ import org.springframework.web.client.RestTemplate;
|
||||
* @since 3.1.1
|
||||
*
|
||||
*/
|
||||
@Configuration
|
||||
@ConditionalOnProperty("AWS_LAMBDA_RUNTIME_API")
|
||||
public class CustomRuntimeEventLoop {
|
||||
final class CustomRuntimeEventLoop {
|
||||
|
||||
private static Log logger = LogFactory.getLog(CustomRuntimeEventLoop.class);
|
||||
|
||||
@@ -62,10 +56,7 @@ public class CustomRuntimeEventLoop {
|
||||
private static final String LAMBDA_RUNTIME_URL_TEMPLATE = "http://{0}/{1}/runtime/invocation/next";
|
||||
private static final String LAMBDA_INVOCATION_URL_TEMPLATE = "http://{0}/{1}/runtime/invocation/{2}/response";
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty("AWS_LAMBDA_RUNTIME_API")
|
||||
public CommandLineRunner backgrounder(ApplicationContext applicationContext) {
|
||||
return args -> eventLoop(applicationContext);
|
||||
private CustomRuntimeEventLoop() {
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@@ -124,12 +115,16 @@ public class CustomRuntimeEventLoop {
|
||||
private static FunctionInvocationWrapper locateFunction(FunctionCatalog functionCatalog, MediaType contentType) {
|
||||
String handlerName = System.getenv("DEFAULT_HANDLER");
|
||||
FunctionInvocationWrapper function = functionCatalog.lookup(handlerName, contentType.toString());
|
||||
if (function == null) {
|
||||
handlerName = System.getenv("_HANDLER");
|
||||
}
|
||||
function = functionCatalog.lookup(handlerName, contentType.toString());
|
||||
if (function == null) {
|
||||
handlerName = System.getenv("spring.cloud.function.definition");
|
||||
}
|
||||
function = functionCatalog.lookup(handlerName, contentType.toString());
|
||||
Assert.notNull(function, "Failed to locate function. Tried locating default function, "
|
||||
+ "function by '_HANDLER' env variable as well as'spring.cloud.function.definition'.");
|
||||
+ "function by 'DEFAULT_HANDLER', '_HANDLER' env variable as well as'spring.cloud.function.definition'.");
|
||||
if (function != null && logger.isInfoEnabled()) {
|
||||
logger.info("Located function " + function.getFunctionDefinition());
|
||||
}
|
||||
|
||||
@@ -47,7 +47,8 @@ public class CustomRuntimeInitializer implements ApplicationContextInitializer<G
|
||||
CommandLineRunner.class, () -> args -> CustomRuntimeEventLoop.eventLoop(context));
|
||||
}
|
||||
}
|
||||
else if (ContextFunctionCatalogInitializer.enabled
|
||||
else
|
||||
if (ContextFunctionCatalogInitializer.enabled
|
||||
&& context.getEnvironment().getProperty("spring.functional.enabled", Boolean.class, false)) {
|
||||
if (context.getBeanFactory().getBeanNamesForType(DestinationResolver.class, false, false).length == 0) {
|
||||
context.registerBean(LambdaDestinationResolver.class, () -> new LambdaDestinationResolver());
|
||||
@@ -60,6 +61,8 @@ public class CustomRuntimeInitializer implements ApplicationContextInitializer<G
|
||||
private boolean isCustomRuntime() {
|
||||
String handler = System.getenv("_HANDLER");
|
||||
if (StringUtils.hasText(handler)) {
|
||||
handler = handler.split(":")[0];
|
||||
logger.info("AWS Handler: " + handler);
|
||||
try {
|
||||
Class<?> clazz = Thread.currentThread().getContextClassLoader().loadClass(handler);
|
||||
if (FunctionInvoker.class.isAssignableFrom(clazz) || AbstractSpringFunctionAdapterInitializer.class.isAssignableFrom(clazz)) {
|
||||
|
||||
@@ -19,21 +19,18 @@ package org.springframework.cloud.function.adapter.aws;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.util.Calendar;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
|
||||
import com.amazonaws.services.lambda.runtime.Context;
|
||||
import com.amazonaws.services.lambda.runtime.RequestStreamHandler;
|
||||
import com.fasterxml.jackson.core.JsonParser;
|
||||
import com.fasterxml.jackson.databind.DeserializationContext;
|
||||
import com.fasterxml.jackson.databind.JsonDeserializer;
|
||||
import com.fasterxml.jackson.databind.MapperFeature;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.databind.module.SimpleModule;
|
||||
import com.fasterxml.jackson.datatype.joda.JodaModule;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.reactivestreams.Publisher;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.cloud.function.context.FunctionCatalog;
|
||||
@@ -43,6 +40,7 @@ import org.springframework.context.ConfigurableApplicationContext;
|
||||
import org.springframework.core.env.Environment;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.MessageHeaders;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StreamUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
@@ -67,33 +65,70 @@ public class FunctionInvoker implements RequestStreamHandler {
|
||||
this.start();
|
||||
}
|
||||
|
||||
@SuppressWarnings({ "unchecked", "rawtypes" })
|
||||
@SuppressWarnings("rawtypes")
|
||||
@Override
|
||||
public void handleRequest(InputStream input, OutputStream output, Context context) throws IOException {
|
||||
final byte[] payload = StreamUtils.copyToByteArray(input);
|
||||
|
||||
if (logger.isInfoEnabled()) {
|
||||
logger.info("Received: " + new String(payload, StandardCharsets.UTF_8));
|
||||
}
|
||||
|
||||
Message requestMessage = AWSLambdaUtils
|
||||
.generateMessage(payload, new MessageHeaders(Collections.emptyMap()), function.getInputType(), this.objectMapper, context);
|
||||
|
||||
Message<byte[]> responseMessage = (Message<byte[]>) this.function.apply(requestMessage);
|
||||
|
||||
byte[] responseBytes = AWSLambdaUtils.generateOutput(requestMessage, responseMessage, this.objectMapper);
|
||||
Object response = this.function.apply(requestMessage);
|
||||
|
||||
byte[] responseBytes = this.buildResult(requestMessage, response);
|
||||
StreamUtils.copy(responseBytes, output);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private byte[] buildResult(Message<?> requestMessage, Object output) throws IOException {
|
||||
Message<byte[]> responseMessage;
|
||||
if (output instanceof Publisher<?>) {
|
||||
List<Object> result = new ArrayList<>();
|
||||
for (Object value : Flux.from((Publisher<?>) output).toIterable()) {
|
||||
if (logger.isInfoEnabled()) {
|
||||
logger.info("Response value: " + value);
|
||||
}
|
||||
result.add(value);
|
||||
}
|
||||
if (result.size() > 1) {
|
||||
output = result;
|
||||
}
|
||||
else {
|
||||
output = result.get(0);
|
||||
}
|
||||
|
||||
if (logger.isInfoEnabled()) {
|
||||
logger.info("OUTPUT: " + output + " - " + output.getClass().getName());
|
||||
}
|
||||
|
||||
byte[] payload = this.objectMapper.writeValueAsBytes(output);
|
||||
responseMessage = MessageBuilder.withPayload(payload).build();
|
||||
}
|
||||
else {
|
||||
responseMessage = (Message<byte[]>) output;
|
||||
}
|
||||
return AWSLambdaUtils.generateOutput(requestMessage, responseMessage, this.objectMapper);
|
||||
}
|
||||
|
||||
private void start() {
|
||||
ConfigurableApplicationContext context = SpringApplication.run(FunctionClassUtils.getStartClass());
|
||||
Environment environment = context.getEnvironment();
|
||||
String functionName = environment.getProperty("spring.cloud.function.definition");
|
||||
FunctionCatalog functionCatalog = context.getBean(FunctionCatalog.class);
|
||||
this.objectMapper = context.getBean(ObjectMapper.class);
|
||||
//this.configureObjectMapper();
|
||||
|
||||
if (logger.isInfoEnabled()) {
|
||||
logger.info("Locating function: '" + functionName + "'");
|
||||
}
|
||||
|
||||
this.function = functionCatalog.lookup(functionName, "application/json");
|
||||
if (this.function.isOutputTypePublisher()) {
|
||||
this.function.setSkipOutputConversion(true);
|
||||
}
|
||||
Assert.notNull(this.function, "Failed to lookup function " + functionName);
|
||||
|
||||
if (!StringUtils.hasText(functionName)) {
|
||||
@@ -104,20 +139,4 @@ public class FunctionInvoker implements RequestStreamHandler {
|
||||
logger.info("Located function: '" + functionName + "'");
|
||||
}
|
||||
}
|
||||
|
||||
private void configureObjectMapper() {
|
||||
SimpleModule module = new SimpleModule();
|
||||
module.addDeserializer(Date.class, new JsonDeserializer<Date>() {
|
||||
@Override
|
||||
public Date deserialize(JsonParser jsonParser, DeserializationContext deserializationContext)
|
||||
throws IOException {
|
||||
Calendar calendar = Calendar.getInstance();
|
||||
calendar.setTimeInMillis(jsonParser.getValueAsLong());
|
||||
return calendar.getTime();
|
||||
}
|
||||
});
|
||||
this.objectMapper.registerModule(module);
|
||||
this.objectMapper.registerModule(new JodaModule());
|
||||
this.objectMapper.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ import com.amazonaws.services.lambda.runtime.events.SNSEvent;
|
||||
import com.amazonaws.services.lambda.runtime.events.SQSEvent;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
@@ -51,6 +52,8 @@ public class FunctionInvokerTests {
|
||||
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
|
||||
String jsonCollection = "[\"Ricky\",\"Julien\",\"Bubbles\"]";
|
||||
|
||||
String sampleLBEvent = "{" +
|
||||
" \"requestContext\": {" +
|
||||
" \"elb\": {" +
|
||||
@@ -360,6 +363,20 @@ public class FunctionInvokerTests {
|
||||
" \"isBase64Encoded\": false\n" +
|
||||
"}";
|
||||
|
||||
|
||||
@Test
|
||||
public void testCollection() throws Exception {
|
||||
System.setProperty("MAIN_CLASS", SampleConfiguration.class.getName());
|
||||
System.setProperty("spring.cloud.function.definition", "echoStringReactive");
|
||||
FunctionInvoker invoker = new FunctionInvoker();
|
||||
|
||||
InputStream targetStream = new ByteArrayInputStream(this.jsonCollection.getBytes());
|
||||
ByteArrayOutputStream output = new ByteArrayOutputStream();
|
||||
invoker.handleRequest(targetStream, output, null);
|
||||
String result = new String(output.toByteArray(), StandardCharsets.UTF_8);
|
||||
assertThat(result).isEqualTo(this.jsonCollection);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testKinesisStringEvent() throws Exception {
|
||||
System.setProperty("MAIN_CLASS", KinesisConfiguration.class.getName());
|
||||
@@ -689,6 +706,20 @@ public class FunctionInvokerTests {
|
||||
assertThat(result.get("body")).isEqualTo("\"OK\"");
|
||||
}
|
||||
|
||||
@EnableAutoConfiguration
|
||||
@Configuration
|
||||
public static class SampleConfiguration {
|
||||
@Bean
|
||||
public Function<String, String> echoString() {
|
||||
return v -> v;
|
||||
}
|
||||
|
||||
@Bean
|
||||
public Function<Flux<String>, Flux<String>> echoStringReactive() {
|
||||
return v -> v;
|
||||
}
|
||||
}
|
||||
|
||||
@EnableAutoConfiguration
|
||||
@Configuration
|
||||
public static class KinesisConfiguration {
|
||||
|
||||
Reference in New Issue
Block a user