Restructure and simplify AWS Custom Runtime code

Add Custom Runtime emulator to simplify integration testing
This commit is contained in:
Oleg Zhurakousky
2021-10-20 17:52:52 +02:00
parent f7112d1ef5
commit 2addf5af7d
8 changed files with 316 additions and 274 deletions

View File

@@ -16,6 +16,7 @@
package org.springframework.cloud.function.adapter.aws;
import java.net.SocketException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.text.MessageFormat;
@@ -23,16 +24,18 @@ import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.boot.CommandLineRunner;
import org.springframework.cloud.function.context.FunctionCatalog;
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.SmartLifecycle;
import org.springframework.core.env.Environment;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
@@ -50,7 +53,7 @@ import org.springframework.web.client.RestTemplate;
* @since 3.1.1
*
*/
public final class CustomRuntimeEventLoop implements CommandLineRunner {
public final class CustomRuntimeEventLoop implements SmartLifecycle {
private static Log logger = LogFactory.getLog(CustomRuntimeEventLoop.class);
@@ -60,23 +63,30 @@ public final class CustomRuntimeEventLoop implements CommandLineRunner {
private final ConfigurableApplicationContext applicationContext;
private volatile boolean running;
private ExecutorService executor = Executors.newSingleThreadExecutor();
public CustomRuntimeEventLoop(ConfigurableApplicationContext applicationContext) {
this.applicationContext = applicationContext;
}
@Override
public void run(String... args) throws Exception {
CustomRuntimeEventLoop.eventLoop(this.applicationContext, args);
public void run() {
this.running = true;
this.executor.execute(() -> {
eventLoop(this.applicationContext);
});
}
@SuppressWarnings("unchecked")
private static void eventLoop(ApplicationContext context, String... args) {
private void eventLoop(ConfigurableApplicationContext context) {
Environment environment = context.getEnvironment();
logger.info("Starting spring-cloud-function CustomRuntimeEventLoop");
if (logger.isDebugEnabled()) {
logger.debug("AWS LAMBDA ENVIRONMENT: " + System.getenv());
}
String runtimeApi = System.getenv("AWS_LAMBDA_RUNTIME_API");
String runtimeApi = environment.getProperty("AWS_LAMBDA_RUNTIME_API");
String eventUri = MessageFormat.format(LAMBDA_RUNTIME_URL_TEMPLATE, runtimeApi, LAMBDA_VERSION_DATE);
if (logger.isDebugEnabled()) {
logger.debug("Event URI: " + eventUri);
@@ -88,49 +98,61 @@ public final class CustomRuntimeEventLoop implements CommandLineRunner {
ObjectMapper mapper = context.getBean(ObjectMapper.class);
logger.info("Entering event loop");
while (isContinue()) {
while (this.isRunning()) {
logger.debug("Attempting to get new event");
ResponseEntity<String> response = rest.exchange(requestEntity, String.class);
ResponseEntity<String> response = this.pollForData(rest, requestEntity);
if (logger.isDebugEnabled()) {
logger.debug("New Event received: " + response);
}
FunctionInvocationWrapper function = locateFunction(functionCatalog, response.getHeaders().getContentType());
Message<byte[]> eventMessage = AWSLambdaUtils.generateMessage(response.getBody().getBytes(StandardCharsets.UTF_8),
fromHttp(response.getHeaders()), function.getInputType(), mapper);
if (logger.isDebugEnabled()) {
logger.debug("Event message: " + eventMessage);
}
if (response != null) {
FunctionInvocationWrapper function = locateFunction(environment, functionCatalog, response.getHeaders().getContentType());
Message<byte[]> eventMessage = AWSLambdaUtils.generateMessage(response.getBody().getBytes(StandardCharsets.UTF_8),
fromHttp(response.getHeaders()), function.getInputType(), mapper);
if (logger.isDebugEnabled()) {
logger.debug("Event message: " + eventMessage);
}
String requestId = response.getHeaders().getFirst("Lambda-Runtime-Aws-Request-Id");
String invocationUrl = MessageFormat
.format(LAMBDA_INVOCATION_URL_TEMPLATE, runtimeApi, LAMBDA_VERSION_DATE, requestId);
String requestId = response.getHeaders().getFirst("Lambda-Runtime-Aws-Request-Id");
String invocationUrl = MessageFormat
.format(LAMBDA_INVOCATION_URL_TEMPLATE, runtimeApi, LAMBDA_VERSION_DATE, requestId);
Message<byte[]> responseMessage = (Message<byte[]>) function.apply(eventMessage);
Message<byte[]> responseMessage = (Message<byte[]>) function.apply(eventMessage);
if (responseMessage != null && logger.isDebugEnabled()) {
logger.debug("Reply from function: " + responseMessage);
}
if (responseMessage != null && logger.isDebugEnabled()) {
logger.debug("Reply from function: " + responseMessage);
}
byte[] outputBody = AWSLambdaUtils.generateOutput(eventMessage, responseMessage, mapper, function.getOutputType());
ResponseEntity<Object> result = rest
.exchange(RequestEntity.post(URI.create(invocationUrl)).body(outputBody), Object.class);
byte[] outputBody = AWSLambdaUtils.generateOutput(eventMessage, responseMessage, mapper, function.getOutputType());
ResponseEntity<Object> result = rest
.exchange(RequestEntity.post(URI.create(invocationUrl)).body(outputBody), Object.class);
if (logger.isInfoEnabled()) {
logger.info("Result POST status: " + result.getStatusCode());
if (logger.isInfoEnabled()) {
logger.info("Result POST status: " + result.getStatusCode());
}
}
}
}
private static boolean isContinue() {
return Boolean.parseBoolean(System.getProperty("CustomRuntimeEventLoop.continue", "true"));
private ResponseEntity<String> pollForData(RestTemplate rest, RequestEntity<Void> requestEntity) {
try {
return rest.exchange(requestEntity, String.class);
}
catch (Exception e) {
if (e instanceof SocketException) {
this.stop();
// ignore
}
}
return null;
}
private static FunctionInvocationWrapper locateFunction(FunctionCatalog functionCatalog, MediaType contentType) {
String handlerName = System.getenv("DEFAULT_HANDLER");
private FunctionInvocationWrapper locateFunction(Environment environment, FunctionCatalog functionCatalog, MediaType contentType) {
String handlerName = environment.getProperty("DEFAULT_HANDLER");
FunctionInvocationWrapper function = functionCatalog.lookup(handlerName, contentType.toString());
if (function == null) {
handlerName = System.getenv("_HANDLER");
handlerName = environment.getProperty("_HANDLER");
function = functionCatalog.lookup(handlerName, contentType.toString());
}
@@ -139,7 +161,7 @@ public final class CustomRuntimeEventLoop implements CommandLineRunner {
}
if (function == null) {
handlerName = System.getenv("spring.cloud.function.definition");
handlerName = environment.getProperty("spring.cloud.function.definition");
function = functionCatalog.lookup(handlerName, contentType.toString());
}
@@ -156,7 +178,7 @@ public final class CustomRuntimeEventLoop implements CommandLineRunner {
return function;
}
private static MessageHeaders fromHttp(HttpHeaders headers) {
private MessageHeaders fromHttp(HttpHeaders headers) {
Map<String, Object> map = new LinkedHashMap<>();
for (String name : headers.keySet()) {
Collection<?> values = multi(headers.get(name));
@@ -171,7 +193,23 @@ public final class CustomRuntimeEventLoop implements CommandLineRunner {
return new MessageHeaders(map);
}
private static Collection<?> multi(Object value) {
private Collection<?> multi(Object value) {
return value instanceof Collection ? (Collection<?>) value : Arrays.asList(value);
}
@Override
public void start() {
this.run();
}
@Override
public void stop() {
this.executor.shutdownNow();
this.running = false;
}
@Override
public boolean isRunning() {
return this.running;
}
}

View File

@@ -19,12 +19,13 @@ package org.springframework.cloud.function.adapter.aws;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.boot.CommandLineRunner;
import org.springframework.cloud.function.context.AbstractSpringFunctionAdapterInitializer;
import org.springframework.cloud.function.context.config.ContextFunctionCatalogInitializer;
import org.springframework.cloud.function.web.source.DestinationResolver;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.SmartLifecycle;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.env.Environment;
import org.springframework.util.StringUtils;
/**
@@ -37,14 +38,15 @@ public class CustomRuntimeInitializer implements ApplicationContextInitializer<G
@Override
public void initialize(GenericApplicationContext context) {
Environment environment = context.getEnvironment();
if (logger.isDebugEnabled()) {
logger.debug("AWS Environment: " + System.getenv());
}
if (!this.isWebExportEnabled(context) && isCustomRuntime()) {
if (!this.isWebExportEnabled(context) && isCustomRuntime(environment)) {
if (context.getBeanFactory().getBeanNamesForType(CustomRuntimeEventLoop.class, false, false).length == 0) {
context.registerBean(StringUtils.uncapitalize(CustomRuntimeEventLoop.class.getSimpleName()),
CommandLineRunner.class, () -> new CustomRuntimeEventLoop(context));
SmartLifecycle.class, () -> new CustomRuntimeEventLoop(context));
}
}
else if (ContextFunctionCatalogInitializer.enabled
@@ -55,8 +57,8 @@ public class CustomRuntimeInitializer implements ApplicationContextInitializer<G
}
}
private boolean isCustomRuntime() {
String handler = System.getenv("_HANDLER");
private boolean isCustomRuntime(Environment environment) {
String handler = environment.getProperty("_HANDLER");
if (StringUtils.hasText(handler)) {
handler = handler.split(":")[0];
logger.info("AWS Handler: " + handler);

View File

@@ -0,0 +1,92 @@
/*
* Copyright 2012-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.function.adapter.test.aws;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.web.servlet.context.ServletWebServerApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.MimeTypeUtils;
/**
* AWS Custom Runtime emulator to be used for testing.
*
* @author Oleg Zhurakousky
* @since 3.2
*/
@EnableAutoConfiguration
public class AWSCustomRuntime {
BlockingQueue<Object> inputQueue = new ArrayBlockingQueue<>(3);
BlockingQueue<Message<String>> outputQueue = new ArrayBlockingQueue<>(3);
public AWSCustomRuntime(ServletWebServerApplicationContext context) {
int port = context.getWebServer().getPort();
System.setProperty("AWS_LAMBDA_RUNTIME_API", "localhost:" + port);
}
@Bean("2018-06-01/runtime/invocation/consume/response")
Consumer<Message<String>> consume() {
return v -> outputQueue.offer(v);
}
@SuppressWarnings("unchecked")
@Bean("2018-06-01/runtime/invocation/next")
Supplier<Message<String>> supply() {
return () -> {
try {
Object value = inputQueue.poll(Long.MAX_VALUE, TimeUnit.SECONDS);
if (!(value instanceof Message)) {
return MessageBuilder.withPayload((String) value)
.setHeader("Lambda-Runtime-Aws-Request-Id", "consume")
.setHeader("Content-Type",
MimeTypeUtils.APPLICATION_JSON)
.build();
}
else {
return (Message<String>) value;
}
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new IllegalStateException(e);
}
};
}
public Message<String> exchange(Object input) {
inputQueue.offer(input);
try {
return outputQueue.poll(5000, TimeUnit.MILLISECONDS);
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
return null;
}
}
}

View File

@@ -16,192 +16,93 @@
package org.springframework.cloud.function.adapter.aws;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.WebApplicationType;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.SpringBootTest.WebEnvironment;
import org.springframework.boot.web.server.LocalServerPort;
import org.springframework.cloud.function.adapter.test.aws.AWSCustomRuntime;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.stereotype.Component;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.util.MimeTypeUtils;
import static org.assertj.core.api.Assertions.assertThat;
/**
*
* @author Oleg Zhurakousky
*/
@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT, properties = "spring.main.web-application-type=servlet")
@ContextConfiguration(classes = {
CustomRuntimeEventLoopTest.CustomRuntimeEmulatorConfiguration.class })
public class CustomRuntimeEventLoopTest {
@LocalServerPort
private int port;
@Autowired
private CustomRuntimeEmulatorConfiguration configuration;
@SuppressWarnings("unchecked")
private Map<String, String> getEnvironment() throws Exception {
Map<String, String> env = System.getenv();
Field field = env.getClass().getDeclaredField("m");
field.setAccessible(true);
return (Map<String, String>) field.get(env);
}
@BeforeEach
public void before() {
System.setProperty("CustomRuntimeEventLoop.continue", "true");
}
@Test
@DirtiesContext
public void testDefaultFunctionLookup() throws Exception {
this.getEnvironment().put("AWS_LAMBDA_RUNTIME_API", "localhost:" + port);
this.getEnvironment().put("_HANDLER", "uppercase");
try (ConfigurableApplicationContext userContext =
new SpringApplicationBuilder(SingleFunctionConfiguration.class, AWSCustomRuntime.class)
.web(WebApplicationType.SERVLET)
.properties("_HANDLER=uppercase", "server.port=0")
.run()) {
configuration.inputQueue.clear();
configuration.inputQueue.addAll(Arrays.asList("\"ricky\"", "\"julien\"", "\"bubbles\""));
try (ConfigurableApplicationContext userContext = new SpringApplicationBuilder(SingleFunctionConfiguration.class)
.web(WebApplicationType.NONE).run(
"--logging.level.org.springframework.cloud.function=DEBUG",
"--spring.main.lazy-initialization=true")) {
assertThat(configuration.output).size().isEqualTo(3);
assertThat(configuration.output.get(0)).isEqualTo("\"RICKY\"");
assertThat(configuration.output.get(1)).isEqualTo("\"JULIEN\"");
assertThat(configuration.output.get(2)).isEqualTo("\"BUBBLES\"");
AWSCustomRuntime aws = userContext.getBean(AWSCustomRuntime.class);
assertThat(aws.exchange("\"ricky\"").getPayload()).isEqualTo("\"RICKY\"");
assertThat(aws.exchange("\"julien\"").getPayload()).isEqualTo("\"JULIEN\"");
assertThat(aws.exchange("\"bubbles\"").getPayload()).isEqualTo("\"BUBBLES\"");
}
}
@Test
@DirtiesContext
public void testDefaultFunctionAsComponentLookup() throws Exception {
this.getEnvironment().put("AWS_LAMBDA_RUNTIME_API", "localhost:" + port);
this.getEnvironment().put("_HANDLER", "personFunction");
try (ConfigurableApplicationContext userContext =
new SpringApplicationBuilder(PersonFunction.class, AWSCustomRuntime.class)
.web(WebApplicationType.SERVLET)
.properties("_HANDLER=personFunction", "server.port=0")
.run()) {
configuration.inputQueue.clear();
configuration.inputQueue.addAll(Arrays.asList("\"ricky\"", "\"julien\"", "\"bubbles\""));
AWSCustomRuntime aws = userContext.getBean(AWSCustomRuntime.class);
try (ConfigurableApplicationContext userContext = new SpringApplicationBuilder(PersonFunction.class)
.web(WebApplicationType.NONE).run(
"--logging.level.org.springframework.cloud.function=DEBUG",
"--spring.main.lazy-initialization=true")) {
assertThat(configuration.output).size().isEqualTo(3);
assertThat(configuration.output.get(0)).isEqualTo("{\"name\":\"RICKY\"}");
assertThat(configuration.output.get(1)).isEqualTo("{\"name\":\"JULIEN\"}");
assertThat(configuration.output.get(2)).isEqualTo("{\"name\":\"BUBBLES\"}");
assertThat(aws.exchange("\"ricky\"").getPayload()).isEqualTo("{\"name\":\"RICKY\"}");
assertThat(aws.exchange("\"julien\"").getPayload()).isEqualTo("{\"name\":\"JULIEN\"}");
assertThat(aws.exchange("\"bubbles\"").getPayload()).isEqualTo("{\"name\":\"BUBBLES\"}");
}
}
@Test
@DirtiesContext
public void test_HANDLERlookupAndPojoFunction() throws Exception {
this.getEnvironment().put("AWS_LAMBDA_RUNTIME_API", "localhost:" + port);
this.getEnvironment().put("_HANDLER", "uppercasePerson");
try (ConfigurableApplicationContext userContext =
new SpringApplicationBuilder(MultipleFunctionConfiguration.class, AWSCustomRuntime.class)
.web(WebApplicationType.SERVLET)
.properties("_HANDLER=uppercasePerson", "server.port=0")
.run()) {
configuration.inputQueue.clear();
configuration.inputQueue.addAll(Arrays.asList("{\"name\":\"ricky\"}",
"{\"name\":\"julien\"}", "{\"name\":\"bubbles\"}"));
try (ConfigurableApplicationContext userContext = new SpringApplicationBuilder(MultipleFunctionConfiguration.class)
.web(WebApplicationType.NONE).run(
"--logging.level.org.springframework.cloud.function=DEBUG",
"--spring.main.lazy-initialization=true")) {
AWSCustomRuntime aws = userContext.getBean(AWSCustomRuntime.class);
assertThat(configuration.output).size().isEqualTo(3);
assertThat(configuration.output.get(0)).isEqualTo("{\"name\":\"RICKY\"}");
assertThat(configuration.output.get(1)).isEqualTo("{\"name\":\"JULIEN\"}");
assertThat(configuration.output.get(2)).isEqualTo("{\"name\":\"BUBBLES\"}");
assertThat(aws.exchange("\"ricky\"").getPayload()).isEqualTo("{\"name\":\"RICKY\"}");
assertThat(aws.exchange("\"julien\"").getPayload()).isEqualTo("{\"name\":\"JULIEN\"}");
assertThat(aws.exchange("\"bubbles\"").getPayload()).isEqualTo("{\"name\":\"BUBBLES\"}");
}
}
@Test
@DirtiesContext
public void test_definitionLookupAndComposition() throws Exception {
this.getEnvironment().put("AWS_LAMBDA_RUNTIME_API", "localhost:" + port);
System.setProperty("spring.cloud.function.definition", "toPersonJson|uppercasePerson");
try (ConfigurableApplicationContext userContext =
new SpringApplicationBuilder(MultipleFunctionConfiguration.class, AWSCustomRuntime.class)
.web(WebApplicationType.SERVLET)
.properties("_HANDLER=toPersonJson|uppercasePerson", "server.port=0")
.run()) {
configuration.inputQueue.clear();
configuration.inputQueue.addAll(Arrays.asList("\"ricky\"", "\"julien\"", "\"bubbles\""));
AWSCustomRuntime aws = userContext.getBean(AWSCustomRuntime.class);
try (ConfigurableApplicationContext userContext = new SpringApplicationBuilder(MultipleFunctionConfiguration.class)
.web(WebApplicationType.NONE).run(
"--logging.level.org.springframework.cloud.function=DEBUG",
"--spring.main.lazy-initialization=true")) {
assertThat(configuration.output).size().isEqualTo(3);
assertThat(configuration.output.get(0)).isEqualTo("{\"name\":\"RICKY\"}");
assertThat(configuration.output.get(1)).isEqualTo("{\"name\":\"JULIEN\"}");
assertThat(configuration.output.get(2)).isEqualTo("{\"name\":\"BUBBLES\"}");
}
}
@SpringBootConfiguration(proxyBeanMethods = false)
@EnableAutoConfiguration
protected static class CustomRuntimeEmulatorConfiguration {
BlockingQueue<String> inputQueue = new ArrayBlockingQueue<>(3);
List<String> output = new ArrayList<>();
@Bean("2018-06-01/runtime/invocation/consume/response")
public Consumer<Message<String>> consume() {
return v -> output.add(v.getPayload());
assertThat(aws.exchange("\"ricky\"").getPayload()).isEqualTo("{\"name\":\"RICKY\"}");
assertThat(aws.exchange("\"julien\"").getPayload()).isEqualTo("{\"name\":\"JULIEN\"}");
assertThat(aws.exchange("\"bubbles\"").getPayload()).isEqualTo("{\"name\":\"BUBBLES\"}");
}
@Bean("2018-06-01/runtime/invocation/next")
public Supplier<Message<String>> supply() {
return () -> {
try {
String value = inputQueue.poll(Long.MAX_VALUE, TimeUnit.SECONDS);
if (inputQueue.peek() == null) {
System.setProperty("CustomRuntimeEventLoop.continue", "false");
}
return MessageBuilder.withPayload(value)
.setHeader("Lambda-Runtime-Aws-Request-Id", "consume")
.setHeader("Content-Type",
MimeTypeUtils.APPLICATION_JSON)
.build();
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new IllegalStateException(e);
}
};
}
}
@EnableAutoConfiguration
@Configuration
protected static class SingleFunctionConfiguration {
@Bean
public Function<String, String> uppercase() {
@@ -229,9 +130,13 @@ public class CustomRuntimeEventLoopTest {
}
@EnableAutoConfiguration
@Component
@Component("personFunction") // need in test explicitly since it is inner class and name wil be `customRuntimeEventLoopTest.PersonFunction`
public static class PersonFunction implements Function<Person, Person> {
public PersonFunction() {
System.out.println();
}
@Override
public Person apply(Person input) {
return new Person(input.getName().toUpperCase());