Improve WebFlux Protobuf support

- Update javadoc for decoding default instances
 - Refactor and simplify tests
 - Add missing tests
 - Refactor decoding with flatMapIterable instead of
   concatMap and avoid recursive call

Issue: SPR-15776
This commit is contained in:
Sebastien Deleuze
2018-08-09 12:03:40 +02:00
parent 8e571decc1
commit 6b6384a09e
4 changed files with 134 additions and 94 deletions

View File

@@ -18,6 +18,7 @@ package org.springframework.http.codec.protobuf;
import java.io.IOException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
@@ -44,13 +45,17 @@ import org.springframework.util.MimeType;
* A {@code Decoder} that reads {@link com.google.protobuf.Message}s
* using <a href="https://developers.google.com/protocol-buffers/">Google Protocol Buffers</a>.
*
* Flux deserialized via
* <p>Flux deserialized via
* {@link #decode(Publisher, ResolvableType, MimeType, Map)} are expected to use
* <a href="https://developers.google.com/protocol-buffers/docs/techniques?hl=en#streaming">delimited Protobuf messages</a>
* with the size of each message specified before the message itself. Single values deserialized
* via {@link #decodeToMono(Publisher, ResolvableType, MimeType, Map)} are expected to use
* regular Protobuf message format (without the size prepended before the message).
*
* <p>Notice that default instance of Protobuf message produces empty byte array, so
* {@code Mono.just(Msg.getDefaultInstance())} sent over the network will be deserialized
* as an empty {@link Mono}.
*
* <p>To generate {@code Message} Java classes, you need to install the {@code protoc} binary.
*
* <p>This decoder requires Protobuf 3 or higher, and supports
@@ -108,7 +113,7 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
return Flux.from(inputStream)
.concatMap(new MessageDecoderFunction(elementType, this.maxMessageSize));
.flatMapIterable(new MessageDecoderFunction(elementType, this.maxMessageSize));
}
@Override
@@ -152,7 +157,7 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
}
private class MessageDecoderFunction implements Function<DataBuffer, Publisher<? extends Message>> {
private class MessageDecoderFunction implements Function<DataBuffer, Iterable<? extends Message>> {
private final ResolvableType elementType;
@@ -163,55 +168,59 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
private int messageBytesToRead;
public MessageDecoderFunction(ResolvableType elementType, int maxMessageSize) {
this.elementType = elementType;
this.maxMessageSize = maxMessageSize;
}
// TODO Instead of the recursive call, loop over the current DataBuffer,
// produce a list of as many messages as are contained, and save any remaining bytes with flatMapIterable
@Override
public Publisher<? extends Message> apply(DataBuffer input) {
public Iterable<? extends Message> apply(DataBuffer input) {
try {
if (this.output == null) {
int firstByte = input.read();
if (firstByte == -1) {
return Flux.error(new DecodingException("Can't parse message size"));
List<Message> messages = new ArrayList<>();
int remainingBytesToRead;
int chunkBytesToRead;
do {
if (this.output == null) {
int firstByte = input.read();
if (firstByte == -1) {
throw new DecodingException("Can't parse message size");
}
this.messageBytesToRead = CodedInputStream.readRawVarint32(firstByte, input.asInputStream());
if (this.messageBytesToRead > this.maxMessageSize) {
throw new DecodingException(
"The number of bytes to read parsed in the incoming stream (" +
this.messageBytesToRead + ") exceeds the configured limit (" + this.maxMessageSize + ")");
}
this.output = input.factory().allocateBuffer(this.messageBytesToRead);
}
this.messageBytesToRead = CodedInputStream.readRawVarint32(firstByte, input.asInputStream());
if (this.messageBytesToRead > this.maxMessageSize) {
return Flux.error(new DecodingException(
"The number of bytes to read parsed in the incoming stream (" +
this.messageBytesToRead + ") exceeds the configured limit (" + this.maxMessageSize + ")"));
chunkBytesToRead = this.messageBytesToRead >= input.readableByteCount() ?
input.readableByteCount() : this.messageBytesToRead;
remainingBytesToRead = input.readableByteCount() - chunkBytesToRead;
byte[] bytesToWrite = new byte[chunkBytesToRead];
input.read(bytesToWrite, 0, chunkBytesToRead);
this.output.write(bytesToWrite);
this.messageBytesToRead -= chunkBytesToRead;
if (this.messageBytesToRead == 0) {
Message.Builder builder = getMessageBuilder(this.elementType.toClass());
builder.mergeFrom(CodedInputStream.newInstance(this.output.asByteBuffer()), extensionRegistry);
messages.add(builder.build());
DataBufferUtils.release(this.output);
this.output = null;
}
this.output = input.factory().allocateBuffer(this.messageBytesToRead);
}
int chunkBytesToRead = this.messageBytesToRead >= input.readableByteCount() ?
input.readableByteCount() : this.messageBytesToRead;
int remainingBytesToRead = input.readableByteCount() - chunkBytesToRead;
this.output.write(input.slice(input.readPosition(), chunkBytesToRead));
this.messageBytesToRead -= chunkBytesToRead;
Message message = null;
if (this.messageBytesToRead == 0) {
Message.Builder builder = getMessageBuilder(this.elementType.toClass());
builder.mergeFrom(CodedInputStream.newInstance(this.output.asByteBuffer()), extensionRegistry);
message = builder.build();
DataBufferUtils.release(this.output);
this.output = null;
}
if (remainingBytesToRead > 0) {
return Mono.justOrEmpty(message).concatWith(
apply(input.slice(input.readPosition() + chunkBytesToRead, remainingBytesToRead)));
}
else {
return Mono.justOrEmpty(message);
}
} while (remainingBytesToRead > 0);
return messages;
}
catch (IOException ex) {
return Flux.error(new DecodingException("I/O error while parsing input stream", ex));
throw new DecodingException("I/O error while parsing input stream", ex);
}
catch (Exception ex) {
return Flux.error(new DecodingException("Could not read Protobuf message: " + ex.getMessage(), ex));
throw new DecodingException("Could not read Protobuf message: " + ex.getMessage(), ex);
}
}
}

View File

@@ -40,7 +40,7 @@ import org.springframework.util.MimeType;
* An {@code Encoder} that writes {@link com.google.protobuf.Message}s
* using <a href="https://developers.google.com/protocol-buffers/">Google Protocol Buffers</a>.
*
* Flux are serialized using
* <p>Flux are serialized using
* <a href="https://developers.google.com/protocol-buffers/docs/techniques?hl=en#streaming">delimited Protobuf messages</a>
* with the size of each message specified before the message itself. Single values are
* serialized using regular Protobuf message format (without the size prepended before the message).