GH-716 Add support for returning custom status code

Resolves #716
This commit is contained in:
Oleg Zhurakousky
2021-07-30 12:42:48 +02:00
parent ce1265d925
commit ffd4e43d4b
7 changed files with 146 additions and 31 deletions

View File

@@ -27,11 +27,6 @@
</properties>
<dependencies>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>amazon-kinesis-client</artifactId>
<version>1.14.4</version>
</dependency>
<dependency>
<groupId>org.springframework.cloud</groupId>
<artifactId>spring-cloud-function-context</artifactId>
@@ -73,6 +68,12 @@
<version>1.2.1</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>amazon-kinesis-client</artifactId>
<version>1.14.4</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-lambda-java-events</artifactId>
@@ -97,6 +98,7 @@
</exclusion>
</exclusions>
<optional>true</optional>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>

View File

@@ -29,7 +29,9 @@ import java.util.concurrent.atomic.AtomicReference;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayV2HTTPEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayV2HTTPResponse;
import com.amazonaws.services.lambda.runtime.events.KinesisEvent;
import com.amazonaws.services.lambda.runtime.events.S3Event;
import com.amazonaws.services.lambda.runtime.events.SNSEvent;
@@ -163,7 +165,13 @@ final class AWSLambdaUtils {
@SuppressWarnings({ "rawtypes", "unchecked" })
public static byte[] generateOutput(Message requestMessage, Message<byte[]> responseMessage,
ObjectMapper objectMapper) {
ObjectMapper objectMapper, Type functionOutputType) {
Class<?> outputClass = FunctionTypeUtils.getRawType(functionOutputType);
if (outputClass != null && (APIGatewayV2HTTPResponse.class.isAssignableFrom(outputClass)
|| APIGatewayProxyResponseEvent.class.isAssignableFrom(outputClass))) {
return responseMessage.getPayload();
}
if (!objectMapper.isEnabled(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES)) {
@@ -190,7 +198,7 @@ final class AWSLambdaUtils {
}
String body = responseMessage == null
? "\"OK\"" : new String(responseMessage.getPayload(), StandardCharsets.UTF_8).replaceAll("\\\"", "\"");
? "\"OK\"" : new String(responseMessage.getPayload(), StandardCharsets.UTF_8).replaceAll("\\\"", "");
response.put("body", body);
if (responseMessage != null) {

View File

@@ -102,7 +102,7 @@ final class CustomRuntimeEventLoop {
logger.debug("Reply from function: " + responseMessage);
}
byte[] outputBody = AWSLambdaUtils.generateOutput(eventMessage, responseMessage, mapper);
byte[] outputBody = AWSLambdaUtils.generateOutput(eventMessage, responseMessage, mapper, function.getOutputType());
ResponseEntity<Object> result = rest
.exchange(RequestEntity.post(URI.create(invocationUrl)).body(outputBody), Object.class);

View File

@@ -26,6 +26,7 @@ import java.util.List;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestStreamHandler;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -39,6 +40,7 @@ import org.springframework.cloud.function.context.config.RoutingFunction;
import org.springframework.cloud.function.utils.FunctionClassUtils;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.core.env.Environment;
import org.springframework.http.HttpStatus;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.support.MessageBuilder;
@@ -79,10 +81,22 @@ public class FunctionInvoker implements RequestStreamHandler {
Message requestMessage = AWSLambdaUtils
.generateMessage(payload, new MessageHeaders(Collections.emptyMap()), function.getInputType(), this.objectMapper, context);
Object response = this.function.apply(requestMessage);
try {
Object response = this.function.apply(requestMessage);
byte[] responseBytes = this.buildResult(requestMessage, response);
StreamUtils.copy(responseBytes, output);
}
catch (Exception e) {
logger.error(e);
StreamUtils.copy(this.buildExceptionResult(requestMessage, e), output);
}
}
byte[] responseBytes = this.buildResult(requestMessage, response);
StreamUtils.copy(responseBytes, output);
private byte[] buildExceptionResult(Message<?> requestMessage, Exception exception) throws IOException {
APIGatewayProxyResponseEvent event = new APIGatewayProxyResponseEvent();
event.setStatusCode(HttpStatus.EXPECTATION_FAILED.value());
event.setBody(exception.getMessage());
return this.objectMapper.writeValueAsBytes(event);
}
@SuppressWarnings("unchecked")
@@ -113,7 +127,7 @@ public class FunctionInvoker implements RequestStreamHandler {
else {
responseMessage = (Message<byte[]>) output;
}
return AWSLambdaUtils.generateOutput(requestMessage, responseMessage, this.objectMapper);
return AWSLambdaUtils.generateOutput(requestMessage, responseMessage, this.objectMapper, function.getOutputType());
}
private void start() {

View File

@@ -21,13 +21,16 @@ import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayV2HTTPEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayV2HTTPResponse;
import com.amazonaws.services.lambda.runtime.events.KinesisEvent;
import com.amazonaws.services.lambda.runtime.events.S3Event;
import com.amazonaws.services.lambda.runtime.events.SNSEvent;
@@ -46,7 +49,6 @@ import org.springframework.messaging.converter.AbstractMessageConverter;
import org.springframework.util.MimeType;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.fail;
/**
*
@@ -698,7 +700,7 @@ public class FunctionInvokerTests {
invoker.handleRequest(targetStream, output, null);
ObjectMapper mapper = new ObjectMapper();
Map result = mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("\"HELLO\"");
assertThat(result.get("body")).isEqualTo("HELLO");
}
@SuppressWarnings("rawtypes")
@@ -713,7 +715,7 @@ public class FunctionInvokerTests {
invoker.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("\"JIM LAHEY\"");
assertThat(result.get("body")).isEqualTo("JIM LAHEY");
}
@SuppressWarnings("rawtypes")
@@ -729,7 +731,7 @@ public class FunctionInvokerTests {
Map result = mapper.readValue(output.toByteArray(), Map.class);
System.out.println(result);
assertThat(result.get("body")).isEqualTo("\"hello\"");
assertThat(result.get("body")).isEqualTo("hello");
}
@SuppressWarnings("rawtypes")
@@ -745,7 +747,7 @@ public class FunctionInvokerTests {
Map result = mapper.readValue(output.toByteArray(), Map.class);
System.out.println(result);
assertThat(result.get("body")).isEqualTo("\"Hello from Lambda\"");
assertThat(result.get("body")).isEqualTo("Hello from Lambda");
}
@SuppressWarnings("rawtypes")
@@ -761,9 +763,63 @@ public class FunctionInvokerTests {
Map result = mapper.readValue(output.toByteArray(), Map.class);
System.out.println(result);
assertThat(result.get("body")).isEqualTo("\"boom\"");
assertThat(result.get("body")).isEqualTo("boom");
}
@SuppressWarnings("rawtypes")
@Test
public void testApiGatewayInAndOut() throws Exception {
System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName());
System.setProperty("spring.cloud.function.definition", "inputOutputApiEvent");
FunctionInvoker invoker = new FunctionInvoker();
InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
ByteArrayOutputStream output = new ByteArrayOutputStream();
invoker.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("hello");
Map headers = (Map) result.get("headers");
assertThat(headers.get("foo")).isEqualTo("bar");
}
@SuppressWarnings("rawtypes")
@Test
public void testApiGatewayInAndOutV2() throws Exception {
System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName());
System.setProperty("spring.cloud.function.definition", "inputOutputApiEventV2");
FunctionInvoker invoker = new FunctionInvoker();
InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
ByteArrayOutputStream output = new ByteArrayOutputStream();
invoker.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("hello");
Map headers = (Map) result.get("headers");
assertThat(headers.get("foo")).isEqualTo("bar");
}
// @SuppressWarnings("rawtypes")
// @Test
// public void testApiGatewayInAndOutWithException() throws Exception {
// System.setProperty("MAIN_CLASS", ApiGatewayConfiguration.class.getName());
// System.setProperty("spring.cloud.function.definition", "inputOutputApiEventException");
// FunctionInvoker invoker = new FunctionInvoker();
//
// InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
// ByteArrayOutputStream output = new ByteArrayOutputStream();
// invoker.handleRequest(targetStream, output, null);
//
// Map result = mapper.readValue(output.toByteArray(), Map.class);
// assertThat(result.get("body")).isEqualTo("Intentional");
//
// Map headers = (Map) result.get("headers");
// assertThat(headers.get("foo")).isEqualTo("bar");
// }
@SuppressWarnings("rawtypes")
@Test
public void testApiGatewayEventAsMessage() throws Exception {
@@ -777,7 +833,7 @@ public class FunctionInvokerTests {
Map result = mapper.readValue(output.toByteArray(), Map.class);
System.out.println(result);
assertThat(result.get("body")).isEqualTo("\"hello\"");
assertThat(result.get("body")).isEqualTo("hello");
}
@SuppressWarnings("rawtypes")
@@ -793,7 +849,7 @@ public class FunctionInvokerTests {
Map result = mapper.readValue(output.toByteArray(), Map.class);
System.out.println(result);
assertThat(result.get("body")).isEqualTo("\"hello\"");
assertThat(result.get("body")).isEqualTo("hello");
}
@SuppressWarnings("rawtypes")
@@ -818,13 +874,9 @@ public class FunctionInvokerTests {
InputStream targetStream = new ByteArrayInputStream(this.apiGatewayEvent.getBytes());
ByteArrayOutputStream output = new ByteArrayOutputStream();
try {
invoker.handleRequest(targetStream, output, null);
fail();
}
catch (Exception e) {
// success, since no definition nor routing instructions are provided
}
invoker.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toByteArray(), Map.class);
assertThat(((String) result.get("body"))).startsWith("Failed to establish route, since neither were provided:");
}
@SuppressWarnings("rawtypes")
@@ -839,7 +891,7 @@ public class FunctionInvokerTests {
invoker.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("\"olleh\"");
assertThat(result.get("body")).isEqualTo("olleh");
}
@SuppressWarnings("rawtypes")
@@ -855,7 +907,7 @@ public class FunctionInvokerTests {
invoker.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toByteArray(), Map.class);
assertThat(result.get("body")).isEqualTo("\"OLLEH\"");
assertThat(result.get("body")).isEqualTo("OLLEH");
}
@SuppressWarnings("unchecked")
@@ -1086,6 +1138,35 @@ public class FunctionInvokerTests {
};
}
@Bean
public Function<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> inputOutputApiEvent() {
return v -> {
APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent();
response.setBody(v.getBody());
response.setStatusCode(200);
response.setHeaders(Collections.singletonMap("foo", "bar"));
return response;
};
}
@Bean
public Function<APIGatewayV2HTTPEvent, APIGatewayV2HTTPResponse> inputOutputApiEventV2() {
return v -> {
APIGatewayV2HTTPResponse response = new APIGatewayV2HTTPResponse();
response.setBody(v.getBody());
response.setStatusCode(200);
response.setHeaders(Collections.singletonMap("foo", "bar"));
return response;
};
}
@Bean
public Function<APIGatewayV2HTTPEvent, APIGatewayV2HTTPResponse> inputOutputApiEventException() {
return v -> {
throw new IllegalStateException("Intentional");
};
}
@Bean
public Function<APIGatewayV2HTTPEvent, String> inputApiV2Event() {
return v -> {