diff --git a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java index 9d0db71396..8d9ba54d90 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java @@ -172,6 +172,8 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder this.maxMessageSize) { throw new DecodingException( "The number of bytes to read from the incoming stream " + @@ -235,6 +235,57 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements DecoderBase 128 Varints + */ + private boolean readMessageSize(DataBuffer input) { + if (this.offset == 0) { + if (input.readableByteCount() == 0) { + return false; + } + int firstByte = input.read(); + if ((firstByte & 0x80) == 0) { + this.messageBytesToRead = firstByte; + return true; + } + this.messageBytesToRead = firstByte & 0x7f; + this.offset = 7; + } + + if (this.offset < 32) { + for (; this.offset < 32; this.offset += 7) { + if (input.readableByteCount() == 0) { + return false; + } + final int b = input.read(); + this.messageBytesToRead |= (b & 0x7f) << offset; + if ((b & 0x80) == 0) { + this.offset = 0; + return true; + } + } + } + // Keep reading up to 64 bits. + for (; this.offset < 64; this.offset += 7) { + if (input.readableByteCount() == 0) { + return false; + } + final int b = input.read(); + if ((b & 0x80) == 0) { + this.offset = 0; + return true; + } + } + this.offset = 0; + throw new DecodingException("Cannot parse message size: malformed varint"); + } } } diff --git a/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java b/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java index d8796026ec..7d0be96995 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java @@ -128,7 +128,9 @@ public class ProtobufDecoderTests extends AbstractDecoderTestCase input = Flux.just(this.testMsg1, this.testMsg2) .flatMap(msg -> Mono.defer(() -> { DataBuffer buffer = this.bufferFactory.allocateBuffer(); @@ -158,6 +160,44 @@ public class ProtobufDecoderTests extends AbstractDecoderTestCase input = Flux.just(bigMessage, bigMessage) + .flatMap(msg -> Mono.defer(() -> { + DataBuffer buffer = this.bufferFactory.allocateBuffer(); + try { + msg.writeDelimitedTo(buffer.asOutputStream()); + return Mono.just(buffer); + } + catch (IOException e) { + release(buffer); + return Mono.error(e); + } + })) + .flatMap(buffer -> { + int len = 2; + Flux result = Flux.just( + DataBufferUtils.retain(buffer.slice(0, len)), + DataBufferUtils + .retain(buffer.slice(len, buffer.readableByteCount() - len)) + ); + release(buffer); + return result; + }); + + testDecode(input, Msg.class, step -> step + .expectNext(bigMessage) + .expectNext(bigMessage) + .verifyComplete()); + } + @Test public void decodeMergedChunks() throws IOException { DataBuffer buffer = this.bufferFactory.allocateBuffer();