From 6b6384a09ed643252e46551d2cf58dbf5b485c88 Mon Sep 17 00:00:00 2001 From: Sebastien Deleuze Date: Thu, 9 Aug 2018 12:03:40 +0200 Subject: [PATCH] Improve WebFlux Protobuf support - Update javadoc for decoding default instances - Refactor and simplify tests - Add missing tests - Refactor decoding with flatMapIterable instead of concatMap and avoid recursive call Issue: SPR-15776 --- .../http/codec/protobuf/ProtobufDecoder.java | 87 +++++++------ .../http/codec/protobuf/ProtobufEncoder.java | 2 +- .../codec/protobuf/ProtobufDecoderTests.java | 123 ++++++++++-------- .../annotation/ProtobufIntegrationTests.java | 16 +++ 4 files changed, 134 insertions(+), 94 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java index 0b6218cb6b..5eb050637a 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java @@ -18,6 +18,7 @@ package org.springframework.http.codec.protobuf; import java.io.IOException; import java.lang.reflect.Method; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentMap; @@ -44,13 +45,17 @@ import org.springframework.util.MimeType; * A {@code Decoder} that reads {@link com.google.protobuf.Message}s * using Google Protocol Buffers. * - * Flux deserialized via + *

Flux deserialized via * {@link #decode(Publisher, ResolvableType, MimeType, Map)} are expected to use * delimited Protobuf messages * with the size of each message specified before the message itself. Single values deserialized * via {@link #decodeToMono(Publisher, ResolvableType, MimeType, Map)} are expected to use * regular Protobuf message format (without the size prepended before the message). * + *

Notice that default instance of Protobuf message produces empty byte array, so + * {@code Mono.just(Msg.getDefaultInstance())} sent over the network will be deserialized + * as an empty {@link Mono}. + * *

To generate {@code Message} Java classes, you need to install the {@code protoc} binary. * *

This decoder requires Protobuf 3 or higher, and supports @@ -108,7 +113,7 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder hints) { return Flux.from(inputStream) - .concatMap(new MessageDecoderFunction(elementType, this.maxMessageSize)); + .flatMapIterable(new MessageDecoderFunction(elementType, this.maxMessageSize)); } @Override @@ -152,7 +157,7 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder> { + private class MessageDecoderFunction implements Function> { private final ResolvableType elementType; @@ -163,55 +168,59 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder apply(DataBuffer input) { + public Iterable apply(DataBuffer input) { try { - if (this.output == null) { - int firstByte = input.read(); - if (firstByte == -1) { - return Flux.error(new DecodingException("Can't parse message size")); + List messages = new ArrayList<>(); + int remainingBytesToRead; + int chunkBytesToRead; + + do { + if (this.output == null) { + int firstByte = input.read(); + if (firstByte == -1) { + throw new DecodingException("Can't parse message size"); + } + this.messageBytesToRead = CodedInputStream.readRawVarint32(firstByte, input.asInputStream()); + if (this.messageBytesToRead > this.maxMessageSize) { + throw new DecodingException( + "The number of bytes to read parsed in the incoming stream (" + + this.messageBytesToRead + ") exceeds the configured limit (" + this.maxMessageSize + ")"); + } + this.output = input.factory().allocateBuffer(this.messageBytesToRead); } - this.messageBytesToRead = CodedInputStream.readRawVarint32(firstByte, input.asInputStream()); - if (this.messageBytesToRead > this.maxMessageSize) { - return Flux.error(new DecodingException( - "The number of bytes to read parsed in the incoming stream (" + - this.messageBytesToRead + ") exceeds the configured limit (" + this.maxMessageSize + ")")); + + chunkBytesToRead = this.messageBytesToRead >= input.readableByteCount() ? + input.readableByteCount() : this.messageBytesToRead; + remainingBytesToRead = input.readableByteCount() - chunkBytesToRead; + + byte[] bytesToWrite = new byte[chunkBytesToRead]; + input.read(bytesToWrite, 0, chunkBytesToRead); + this.output.write(bytesToWrite); + this.messageBytesToRead -= chunkBytesToRead; + + if (this.messageBytesToRead == 0) { + Message.Builder builder = getMessageBuilder(this.elementType.toClass()); + builder.mergeFrom(CodedInputStream.newInstance(this.output.asByteBuffer()), extensionRegistry); + messages.add(builder.build()); + DataBufferUtils.release(this.output); + this.output = null; } - this.output = input.factory().allocateBuffer(this.messageBytesToRead); - } - int chunkBytesToRead = this.messageBytesToRead >= input.readableByteCount() ? - input.readableByteCount() : this.messageBytesToRead; - int remainingBytesToRead = input.readableByteCount() - chunkBytesToRead; - this.output.write(input.slice(input.readPosition(), chunkBytesToRead)); - this.messageBytesToRead -= chunkBytesToRead; - Message message = null; - if (this.messageBytesToRead == 0) { - Message.Builder builder = getMessageBuilder(this.elementType.toClass()); - builder.mergeFrom(CodedInputStream.newInstance(this.output.asByteBuffer()), extensionRegistry); - message = builder.build(); - DataBufferUtils.release(this.output); - this.output = null; - } - if (remainingBytesToRead > 0) { - return Mono.justOrEmpty(message).concatWith( - apply(input.slice(input.readPosition() + chunkBytesToRead, remainingBytesToRead))); - } - else { - return Mono.justOrEmpty(message); - } + } while (remainingBytesToRead > 0); + return messages; } catch (IOException ex) { - return Flux.error(new DecodingException("I/O error while parsing input stream", ex)); + throw new DecodingException("I/O error while parsing input stream", ex); } catch (Exception ex) { - return Flux.error(new DecodingException("Could not read Protobuf message: " + ex.getMessage(), ex)); + throw new DecodingException("Could not read Protobuf message: " + ex.getMessage(), ex); } } } diff --git a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufEncoder.java b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufEncoder.java index b32f792f6e..b70f2fc520 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufEncoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufEncoder.java @@ -40,7 +40,7 @@ import org.springframework.util.MimeType; * An {@code Encoder} that writes {@link com.google.protobuf.Message}s * using Google Protocol Buffers. * - * Flux are serialized using + *

Flux are serialized using * delimited Protobuf messages * with the size of each message specified before the message itself. Single values are * serialized using regular Protobuf message format (without the size prepended before the message). diff --git a/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java b/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java index 55a3018f89..d1118aa142 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java @@ -16,12 +16,7 @@ package org.springframework.http.codec.protobuf; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.OutputStream; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; import com.google.protobuf.Message; import org.junit.Before; @@ -47,8 +42,6 @@ import static org.springframework.core.ResolvableType.forClass; /** * Unit tests for {@link ProtobufDecoder}. - * TODO Make tests more readable - * TODO Add a test where an input DataBuffer is larger than a message * * @author Sebastien Deleuze */ @@ -56,7 +49,13 @@ public class ProtobufDecoderTests extends AbstractDataBufferAllocatingTestCase { private final static MimeType PROTOBUF_MIME_TYPE = new MimeType("application", "x-protobuf"); - private final Msg testMsg = Msg.newBuilder().setFoo("Foo").setBlah(SecondMsg.newBuilder().setBlah(123).build()).build(); + private final SecondMsg secondMsg = SecondMsg.newBuilder().setBlah(123).build(); + + private final Msg testMsg = Msg.newBuilder().setFoo("Foo").setBlah(secondMsg).build(); + + private final SecondMsg secondMsg2 = SecondMsg.newBuilder().setBlah(456).build(); + + private final Msg testMsg2 = Msg.newBuilder().setFoo("Bar").setBlah(secondMsg2).build(); private ProtobufDecoder decoder; @@ -82,51 +81,59 @@ public class ProtobufDecoderTests extends AbstractDataBufferAllocatingTestCase { @Test public void decodeToMono() { - byte[] body = this.testMsg.toByteArray(); - Flux source = Flux.just(this.bufferFactory.wrap(body)); + DataBuffer data = this.bufferFactory.wrap(testMsg.toByteArray()); ResolvableType elementType = forClass(Msg.class); - Mono mono = this.decoder.decodeToMono(source, elementType, null, - emptyMap()); + + Mono mono = this.decoder.decodeToMono(Flux.just(data), elementType, null, emptyMap()); StepVerifier.create(mono) - .expectNext(this.testMsg) + .expectNext(testMsg) + .verifyComplete(); + } + + @Test + public void decodeToMonoWithLargerDataBuffer() { + DataBuffer buffer = this.bufferFactory.allocateBuffer(1024); + buffer.write(testMsg.toByteArray()); + ResolvableType elementType = forClass(Msg.class); + + Mono mono = this.decoder.decodeToMono(Flux.just(buffer), elementType, null, emptyMap()); + + StepVerifier.create(mono) + .expectNext(testMsg) .verifyComplete(); } @Test public void decodeChunksToMono() { - byte[] body = this.testMsg.toByteArray(); - List chunks = new ArrayList<>(); - chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(body, 0, 4))); - chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(body, 4, body.length))); - Flux source = Flux.fromIterable(chunks); + DataBuffer buffer = this.bufferFactory.wrap(testMsg.toByteArray()); + Flux chunks = Flux.just( + buffer.slice(0, 4), + buffer.slice(4, buffer.readableByteCount() - 4)); + DataBufferUtils.retain(buffer); ResolvableType elementType = forClass(Msg.class); - Mono mono = this.decoder.decodeToMono(source, elementType, null, + + Mono mono = this.decoder.decodeToMono(chunks, elementType, null, emptyMap()); StepVerifier.create(mono) - .expectNext(this.testMsg) + .expectNext(testMsg) .verifyComplete(); } @Test public void decode() throws IOException { - Msg testMsg2 = Msg.newBuilder().setFoo("Bar").setBlah(SecondMsg.newBuilder().setBlah(456).build()).build(); - DataBuffer buffer = bufferFactory.allocateBuffer(); - OutputStream outputStream = buffer.asOutputStream(); - this.testMsg.writeDelimitedTo(outputStream); - + testMsg.writeDelimitedTo(buffer.asOutputStream()); DataBuffer buffer2 = bufferFactory.allocateBuffer(); - OutputStream outputStream2 = buffer2.asOutputStream(); - testMsg2.writeDelimitedTo(outputStream2); - + testMsg2.writeDelimitedTo(buffer2.asOutputStream()); Flux source = Flux.just(buffer, buffer2); ResolvableType elementType = forClass(Msg.class); + Flux messages = this.decoder.decode(source, elementType, null, emptyMap()); StepVerifier.create(messages) - .expectNext(this.testMsg) + .expectNext(testMsg) .expectNext(testMsg2) .verifyComplete(); @@ -135,42 +142,50 @@ public class ProtobufDecoderTests extends AbstractDataBufferAllocatingTestCase { } @Test - public void decodeChunks() throws IOException { - Msg testMsg2 = Msg.newBuilder().setFoo("Bar").setBlah(SecondMsg.newBuilder().setBlah(456).build()).build(); - List chunks = new ArrayList<>(); + public void decodeSplitChunks() throws IOException { + DataBuffer buffer = bufferFactory.allocateBuffer(); + testMsg.writeDelimitedTo(buffer.asOutputStream()); + DataBuffer buffer2 = bufferFactory.allocateBuffer(); + testMsg2.writeDelimitedTo(buffer2.asOutputStream()); + Flux chunks = Flux.just( + buffer.slice(0, 4), + buffer.slice(4, buffer.readableByteCount() - 4), + buffer2.slice(0, 2), + buffer2.slice(2, buffer2.readableByteCount() - 2)); - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - this.testMsg.writeDelimitedTo(outputStream); - byte[] byteArray = outputStream.toByteArray(); - ByteArrayOutputStream outputStream2 = new ByteArrayOutputStream(); - testMsg2.writeDelimitedTo(outputStream2); - byte[] byteArray2 = outputStream2.toByteArray(); - - chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(byteArray, 0, 4))); - byte[] chunk2 = Arrays.copyOfRange(byteArray, 4, byteArray.length); - byte[] chunk3 = Arrays.copyOfRange(byteArray2, 0, 4); - byte[] combined = new byte[chunk2.length + chunk3.length]; - for (int i = 0; i < combined.length; ++i) - { - combined[i] = i < chunk2.length ? chunk2[i] : chunk3[i - chunk2.length]; - } - chunks.add(this.bufferFactory.wrap(combined)); - chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(byteArray2, 4, byteArray2.length))); - - Flux source = Flux.fromIterable(chunks); ResolvableType elementType = forClass(Msg.class); - Flux messages = this.decoder.decode(source, elementType, null, emptyMap()); + Flux messages = this.decoder.decode(chunks, elementType, null, emptyMap()); StepVerifier.create(messages) - .expectNext(this.testMsg) + .expectNext(testMsg) .expectNext(testMsg2) .verifyComplete(); + + DataBufferUtils.release(buffer); + DataBufferUtils.release(buffer2); + } + + @Test + public void decodeMergedChunks() throws IOException { + DataBuffer buffer = bufferFactory.allocateBuffer(); + testMsg.writeDelimitedTo(buffer.asOutputStream()); + testMsg.writeDelimitedTo(buffer.asOutputStream()); + + ResolvableType elementType = forClass(Msg.class); + Flux messages = this.decoder.decode(Mono.just(buffer), elementType, null, emptyMap()); + + StepVerifier.create(messages) + .expectNext(testMsg) + .expectNext(testMsg) + .verifyComplete(); + + DataBufferUtils.release(buffer); } @Test public void exceedMaxSize() { this.decoder.setMaxMessageSize(1); - byte[] body = this.testMsg.toByteArray(); + byte[] body = testMsg.toByteArray(); Flux source = Flux.just(this.bufferFactory.wrap(body)); ResolvableType elementType = forClass(Msg.class); Flux messages = this.decoder.decode(source, elementType, null, diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ProtobufIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ProtobufIntegrationTests.java index 60a0d9c3df..f542bea978 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ProtobufIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ProtobufIntegrationTests.java @@ -129,6 +129,17 @@ public class ProtobufIntegrationTests extends AbstractRequestMappingIntegrationT .verifyComplete(); } + @Test + public void defaultInstance() { + Mono result = this.webClient.get() + .uri("/default-instance") + .retrieve() + .bodyToMono(Msg.class); + + StepVerifier.create(result) + .verifyComplete(); + } + @RestController @SuppressWarnings("unused") static class ProtobufController { @@ -153,6 +164,11 @@ public class ProtobufIntegrationTests extends AbstractRequestMappingIntegrationT return Mono.empty(); } + @GetMapping("default-instance") + Mono defaultInstance() { + return Mono.just(Msg.getDefaultInstance()); + } + } @Configuration