Fix for reactive function in CustomRuntime

This commit is contained in:
Urs Keller
2023-05-09 17:27:22 +02:00
committed by Oleg Zhurakousky
parent 3b88030038
commit 2dcad9a88d
3 changed files with 46 additions and 45 deletions

View File

@@ -20,13 +20,17 @@ import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import com.amazonaws.services.lambda.runtime.Context;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
import org.springframework.cloud.function.json.JsonMapper;
@@ -141,6 +145,43 @@ public final class AWSLambdaUtils {
}
}
@SuppressWarnings("unchecked")
public static byte[] generateOutputFromObject(Message<?> requestMessage, Object output, JsonMapper objectMapper, Type functionOutputType) {
Message<byte[]> responseMessage = null;
if (output instanceof Publisher<?>) {
List<Object> result = new ArrayList<>();
for (Object value : Flux.from((Publisher<?>) output).toIterable()) {
if (logger.isDebugEnabled()) {
logger.debug("Response value: " + value);
}
result.add(value);
}
if (result.size() > 1) {
output = result;
}
else if (result.size() == 1) {
output = result.get(0);
}
else {
output = null;
}
if (output instanceof Message<?> && ((Message<?>) output).getPayload() instanceof byte[]) {
responseMessage = (Message<byte[]>) output;
}
else if (output != null) {
if (logger.isDebugEnabled()) {
logger.debug("OUTPUT: " + output + " - " + output.getClass().getName());
}
byte[] payload = objectMapper.toJson(output);
responseMessage = MessageBuilder.withPayload(payload).build();
}
}
else {
responseMessage = (Message<byte[]>) output;
}
return generateOutput(requestMessage, responseMessage, objectMapper, functionOutputType);
}
@SuppressWarnings({ "rawtypes", "unchecked" })
public static byte[] generateOutput(Message requestMessage, Message<?> responseMessage,
JsonMapper objectMapper, Type functionOutputType) {

View File

@@ -157,14 +157,13 @@ public final class CustomRuntimeEventLoop implements SmartLifecycle {
}
System.setProperty("com.amazonaws.xray.traceHeader", traceId);
}
Object responseObject = function.apply(eventMessage);
Message<byte[]> responseMessage = (Message<byte[]>) function.apply(eventMessage);
if (responseMessage != null && logger.isDebugEnabled()) {
logger.debug("Reply from function: " + responseMessage);
if (responseObject != null && logger.isDebugEnabled()) {
logger.debug("Reply from function: " + responseObject);
}
byte[] outputBody = AWSLambdaUtils.generateOutput(eventMessage, responseMessage, mapper, function.getOutputType());
byte[] outputBody = AWSLambdaUtils.generateOutputFromObject(eventMessage, responseObject, mapper, function.getOutputType());
ResponseEntity<Object> result = rest.exchange(RequestEntity.post(URI.create(invocationUrl))
.header(USER_AGENT, USER_AGENT_VALUE)
.body(outputBody), Object.class);

View File

@@ -19,8 +19,6 @@ package org.springframework.cloud.function.adapter.aws;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import com.amazonaws.services.lambda.runtime.Context;
@@ -28,8 +26,6 @@ import com.amazonaws.services.lambda.runtime.RequestStreamHandler;
import com.fasterxml.jackson.databind.MapperFeature;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import org.springframework.boot.SpringApplication;
import org.springframework.cloud.function.context.FunctionCatalog;
@@ -44,7 +40,6 @@ import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.core.env.Environment;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.StreamUtils;
import org.springframework.util.StringUtils;
@@ -83,45 +78,11 @@ public class FunctionInvoker implements RequestStreamHandler {
.generateMessage(input, this.function.getInputType(), this.function.isSupplier(), jsonMapper, context);
Object response = this.function.apply(requestMessage);
byte[] responseBytes = this.buildResult(requestMessage, response);
byte[] responseBytes = AWSLambdaUtils.generateOutputFromObject(requestMessage, response, this.jsonMapper, function.getOutputType());
StreamUtils.copy(responseBytes, output);
// any exception should propagate
}
@SuppressWarnings("unchecked")
private byte[] buildResult(Message<?> requestMessage, Object output) throws IOException {
Message<byte[]> responseMessage = null;
if (output instanceof Publisher<?>) {
List<Object> result = new ArrayList<>();
for (Object value : Flux.from((Publisher<?>) output).toIterable()) {
if (logger.isDebugEnabled()) {
logger.debug("Response value: " + value);
}
result.add(value);
}
if (result.size() > 1) {
output = result;
}
else if (result.size() == 1) {
output = result.get(0);
}
else {
output = null;
}
if (output != null) {
if (logger.isDebugEnabled()) {
logger.debug("OUTPUT: " + output + " - " + output.getClass().getName());
}
byte[] payload = this.jsonMapper.toJson(output);
responseMessage = MessageBuilder.withPayload(payload).build();
}
}
else {
responseMessage = (Message<byte[]>) output;
}
return AWSLambdaUtils.generateOutput(requestMessage, responseMessage, this.jsonMapper, function.getOutputType());
}
private void start() {
Class<?> startClass = FunctionClassUtils.getStartClass();
String[] properties = new String[] {"--spring.cloud.function.web.export.enabled=false", "--spring.main.web-application-type=none"};