Support Protobuf serialization in WebFlux

This commit introduces Protobuf support in WebFlux via dedicated
codecs.

Flux<Message> are serialized/deserialized using delimited Protobuf
messages with the size of each message specified before the message
itself. In that case, a "delimited=true" parameter is added to the
content type.

Mono<Message> are expected to use regular Protobuf message
format (without the size prepended before the message).

Related HttpMessageReader/Writer are automatically registered when the
"com.google.protobuf:protobuf-java" library is detected in the classpath,
and can be customized easily if needed via CodecConfigurer, for example
to specify protocol extensions via the ExtensionRegistry based
constructors.

Both "application/x-protobuf" and "application/octet-stream" mime types
are supported.

Issue: SPR-15776
This commit is contained in:
sdeleuze
2018-03-19 18:16:46 +01:00
committed by Sebastien Deleuze
parent 4475c67ba8
commit 36a07aa897
23 changed files with 2225 additions and 18 deletions

View File

@@ -0,0 +1,183 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.codec.protobuf;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import com.google.protobuf.Message;
import org.junit.Before;
import org.junit.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.DecodingException;
import org.springframework.core.io.buffer.AbstractDataBufferAllocatingTestCase;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.MediaType;
import org.springframework.protobuf.Msg;
import org.springframework.protobuf.SecondMsg;
import org.springframework.util.MimeType;
import static java.util.Collections.emptyMap;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.springframework.core.ResolvableType.forClass;
/**
* Unit tests for {@link ProtobufDecoder}.
* TODO Make tests more readable
* TODO Add a test where an input DataBuffer is larger than a message
*
* @author Sebastien Deleuze
*/
public class ProtobufDecoderTests extends AbstractDataBufferAllocatingTestCase {
private final static MimeType PROTOBUF_MIME_TYPE = new MimeType("application", "x-protobuf");
private final Msg testMsg = Msg.newBuilder().setFoo("Foo").setBlah(SecondMsg.newBuilder().setBlah(123).build()).build();
private ProtobufDecoder decoder;
@Before
public void setup() {
this.decoder = new ProtobufDecoder();
}
@Test(expected = IllegalArgumentException.class)
public void extensionRegistryNull() {
new ProtobufDecoder(null);
}
@Test
public void canDecode() {
assertTrue(this.decoder.canDecode(forClass(Msg.class), null));
assertTrue(this.decoder.canDecode(forClass(Msg.class), PROTOBUF_MIME_TYPE));
assertTrue(this.decoder.canDecode(forClass(Msg.class), MediaType.APPLICATION_OCTET_STREAM));
assertFalse(this.decoder.canDecode(forClass(Msg.class), MediaType.APPLICATION_JSON));
assertFalse(this.decoder.canDecode(forClass(Object.class), PROTOBUF_MIME_TYPE));
}
@Test
public void decodeToMono() {
byte[] body = this.testMsg.toByteArray();
Flux<DataBuffer> source = Flux.just(this.bufferFactory.wrap(body));
ResolvableType elementType = forClass(Msg.class);
Mono<Message> mono = this.decoder.decodeToMono(source, elementType, null,
emptyMap());
StepVerifier.create(mono)
.expectNext(this.testMsg)
.verifyComplete();
}
@Test
public void decodeChunksToMono() {
byte[] body = this.testMsg.toByteArray();
List<DataBuffer> chunks = new ArrayList<>();
chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(body, 0, 4)));
chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(body, 4, body.length)));
Flux<DataBuffer> source = Flux.fromIterable(chunks);
ResolvableType elementType = forClass(Msg.class);
Mono<Message> mono = this.decoder.decodeToMono(source, elementType, null,
emptyMap());
StepVerifier.create(mono)
.expectNext(this.testMsg)
.verifyComplete();
}
@Test
public void decode() throws IOException {
Msg testMsg2 = Msg.newBuilder().setFoo("Bar").setBlah(SecondMsg.newBuilder().setBlah(456).build()).build();
DataBuffer buffer = bufferFactory.allocateBuffer();
OutputStream outputStream = buffer.asOutputStream();
this.testMsg.writeDelimitedTo(outputStream);
DataBuffer buffer2 = bufferFactory.allocateBuffer();
OutputStream outputStream2 = buffer2.asOutputStream();
testMsg2.writeDelimitedTo(outputStream2);
Flux<DataBuffer> source = Flux.just(buffer, buffer2);
ResolvableType elementType = forClass(Msg.class);
Flux<Message> messages = this.decoder.decode(source, elementType, null, emptyMap());
StepVerifier.create(messages)
.expectNext(this.testMsg)
.expectNext(testMsg2)
.verifyComplete();
DataBufferUtils.release(buffer);
DataBufferUtils.release(buffer2);
}
@Test
public void decodeChunks() throws IOException {
Msg testMsg2 = Msg.newBuilder().setFoo("Bar").setBlah(SecondMsg.newBuilder().setBlah(456).build()).build();
List<DataBuffer> chunks = new ArrayList<>();
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
this.testMsg.writeDelimitedTo(outputStream);
byte[] byteArray = outputStream.toByteArray();
ByteArrayOutputStream outputStream2 = new ByteArrayOutputStream();
testMsg2.writeDelimitedTo(outputStream2);
byte[] byteArray2 = outputStream2.toByteArray();
chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(byteArray, 0, 4)));
byte[] chunk2 = Arrays.copyOfRange(byteArray, 4, byteArray.length);
byte[] chunk3 = Arrays.copyOfRange(byteArray2, 0, 4);
byte[] combined = new byte[chunk2.length + chunk3.length];
for (int i = 0; i < combined.length; ++i)
{
combined[i] = i < chunk2.length ? chunk2[i] : chunk3[i - chunk2.length];
}
chunks.add(this.bufferFactory.wrap(combined));
chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(byteArray2, 4, byteArray2.length)));
Flux<DataBuffer> source = Flux.fromIterable(chunks);
ResolvableType elementType = forClass(Msg.class);
Flux<Message> messages = this.decoder.decode(source, elementType, null, emptyMap());
StepVerifier.create(messages)
.expectNext(this.testMsg)
.expectNext(testMsg2)
.verifyComplete();
}
@Test
public void exceedMaxSize() {
this.decoder.setMaxMessageSize(1);
byte[] body = this.testMsg.toByteArray();
Flux<DataBuffer> source = Flux.just(this.bufferFactory.wrap(body));
ResolvableType elementType = forClass(Msg.class);
Flux<Message> messages = this.decoder.decode(source, elementType, null,
emptyMap());
StepVerifier.create(messages)
.verifyError(DecodingException.class);
}
}

View File

@@ -0,0 +1,111 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.codec.protobuf;
import java.io.IOException;
import java.io.UncheckedIOException;
import com.google.protobuf.Message;
import org.junit.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.core.ResolvableType;
import org.springframework.core.io.buffer.AbstractDataBufferAllocatingTestCase;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.MediaType;
import org.springframework.protobuf.Msg;
import org.springframework.protobuf.SecondMsg;
import org.springframework.util.MimeType;
import static java.util.Collections.emptyMap;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.springframework.core.ResolvableType.forClass;
/**
* Unit tests for {@link ProtobufEncoder}.
*
* @author Sebastien Deleuze
*/
public class ProtobufEncoderTests extends AbstractDataBufferAllocatingTestCase {
private final static MimeType PROTOBUF_MIME_TYPE = new MimeType("application", "x-protobuf");
private final Msg testMsg = Msg.newBuilder().setFoo("Foo").setBlah(SecondMsg.newBuilder().setBlah(123).build()).build();
private final ProtobufEncoder encoder = new ProtobufEncoder();
@Test
public void canEncode() {
assertTrue(this.encoder.canEncode(forClass(Msg.class), null));
assertTrue(this.encoder.canEncode(forClass(Msg.class), PROTOBUF_MIME_TYPE));
assertTrue(this.encoder.canEncode(forClass(Msg.class), MediaType.APPLICATION_OCTET_STREAM));
assertFalse(this.encoder.canEncode(forClass(Msg.class), MediaType.APPLICATION_JSON));
assertFalse(this.encoder.canEncode(forClass(Object.class), PROTOBUF_MIME_TYPE));
}
@Test
public void encode() {
Mono<Message> message = Mono.just(this.testMsg);
ResolvableType elementType = forClass(Msg.class);
Flux<DataBuffer> output = this.encoder.encode(message, this.bufferFactory, elementType, PROTOBUF_MIME_TYPE, emptyMap());
StepVerifier.create(output)
.consumeNextWith(dataBuffer -> {
try {
assertEquals(this.testMsg, Msg.parseFrom(dataBuffer.asInputStream()));
DataBufferUtils.release(dataBuffer);
}
catch (IOException ex) {
throw new UncheckedIOException(ex);
}
})
.verifyComplete();
}
@Test
public void encodeStream() {
Msg testMsg2 = Msg.newBuilder().setFoo("Bar").setBlah(SecondMsg.newBuilder().setBlah(456).build()).build();
Flux<Message> messages = Flux.just(this.testMsg, testMsg2);
ResolvableType elementType = forClass(Msg.class);
Flux<DataBuffer> output = this.encoder.encode(messages, this.bufferFactory, elementType, PROTOBUF_MIME_TYPE, emptyMap());
StepVerifier.create(output)
.consumeNextWith(dataBuffer -> {
try {
assertEquals(this.testMsg, Msg.parseDelimitedFrom(dataBuffer.asInputStream()));
DataBufferUtils.release(dataBuffer);
}
catch (IOException ex) {
throw new UncheckedIOException(ex);
}
})
.consumeNextWith(dataBuffer -> {
try {
assertEquals(testMsg2, Msg.parseDelimitedFrom(dataBuffer.asInputStream()));
DataBufferUtils.release(dataBuffer);
}
catch (IOException ex) {
throw new UncheckedIOException(ex);
}
})
.verifyComplete();
}
}

View File

@@ -53,6 +53,8 @@ import org.springframework.http.codec.json.Jackson2JsonEncoder;
import org.springframework.http.codec.json.Jackson2SmileDecoder;
import org.springframework.http.codec.json.Jackson2SmileEncoder;
import org.springframework.http.codec.multipart.MultipartHttpMessageWriter;
import org.springframework.http.codec.protobuf.ProtobufDecoder;
import org.springframework.http.codec.protobuf.ProtobufHttpMessageWriter;
import org.springframework.http.codec.xml.Jaxb2XmlDecoder;
import org.springframework.http.codec.xml.Jaxb2XmlEncoder;
import org.springframework.util.MimeTypeUtils;
@@ -75,12 +77,13 @@ public class ClientCodecConfigurerTests {
@Test
public void defaultReaders() {
List<HttpMessageReader<?>> readers = this.configurer.getReaders();
assertEquals(11, readers.size());
assertEquals(12, readers.size());
assertEquals(ByteArrayDecoder.class, getNextDecoder(readers).getClass());
assertEquals(ByteBufferDecoder.class, getNextDecoder(readers).getClass());
assertEquals(DataBufferDecoder.class, getNextDecoder(readers).getClass());
assertEquals(ResourceDecoder.class, getNextDecoder(readers).getClass());
assertStringDecoder(getNextDecoder(readers), true);
assertEquals(ProtobufDecoder.class, getNextDecoder(readers).getClass());
assertEquals(FormHttpMessageReader.class, readers.get(this.index.getAndIncrement()).getClass()); // SPR-16804
assertEquals(Jackson2JsonDecoder.class, getNextDecoder(readers).getClass());
assertEquals(Jackson2SmileDecoder.class, getNextDecoder(readers).getClass());
@@ -92,13 +95,14 @@ public class ClientCodecConfigurerTests {
@Test
public void defaultWriters() {
List<HttpMessageWriter<?>> writers = this.configurer.getWriters();
assertEquals(10, writers.size());
assertEquals(11, writers.size());
assertEquals(ByteArrayEncoder.class, getNextEncoder(writers).getClass());
assertEquals(ByteBufferEncoder.class, getNextEncoder(writers).getClass());
assertEquals(DataBufferEncoder.class, getNextEncoder(writers).getClass());
assertEquals(ResourceHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass());
assertStringEncoder(getNextEncoder(writers), true);
assertEquals(MultipartHttpMessageWriter.class, writers.get(this.index.getAndIncrement()).getClass());
assertEquals(ProtobufHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass());
assertEquals(Jackson2JsonEncoder.class, getNextEncoder(writers).getClass());
assertEquals(Jackson2SmileEncoder.class, getNextEncoder(writers).getClass());
assertEquals(Jaxb2XmlEncoder.class, getNextEncoder(writers).getClass());

View File

@@ -19,6 +19,7 @@ package org.springframework.http.codec.support;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import com.google.protobuf.ExtensionRegistry;
import org.junit.Test;
import org.springframework.core.ResolvableType;
@@ -45,6 +46,9 @@ import org.springframework.http.codec.json.Jackson2JsonDecoder;
import org.springframework.http.codec.json.Jackson2JsonEncoder;
import org.springframework.http.codec.json.Jackson2SmileDecoder;
import org.springframework.http.codec.json.Jackson2SmileEncoder;
import org.springframework.http.codec.protobuf.ProtobufDecoder;
import org.springframework.http.codec.protobuf.ProtobufEncoder;
import org.springframework.http.codec.protobuf.ProtobufHttpMessageWriter;
import org.springframework.http.codec.xml.Jaxb2XmlDecoder;
import org.springframework.http.codec.xml.Jaxb2XmlEncoder;
import org.springframework.util.MimeTypeUtils;
@@ -56,6 +60,7 @@ import static org.springframework.core.ResolvableType.forClass;
/**
* Unit tests for {@link BaseDefaultCodecs}.
* @author Rossen Stoyanchev
* @author Sebastien Deleuze
*/
public class CodecConfigurerTests {
@@ -67,12 +72,13 @@ public class CodecConfigurerTests {
@Test
public void defaultReaders() {
List<HttpMessageReader<?>> readers = this.configurer.getReaders();
assertEquals(10, readers.size());
assertEquals(11, readers.size());
assertEquals(ByteArrayDecoder.class, getNextDecoder(readers).getClass());
assertEquals(ByteBufferDecoder.class, getNextDecoder(readers).getClass());
assertEquals(DataBufferDecoder.class, getNextDecoder(readers).getClass());
assertEquals(ResourceDecoder.class, getNextDecoder(readers).getClass());
assertStringDecoder(getNextDecoder(readers), true);
assertEquals(ProtobufDecoder.class, getNextDecoder(readers).getClass());
assertEquals(FormHttpMessageReader.class, readers.get(this.index.getAndIncrement()).getClass());
assertEquals(Jackson2JsonDecoder.class, getNextDecoder(readers).getClass());
assertEquals(Jackson2SmileDecoder.class, getNextDecoder(readers).getClass());
@@ -83,12 +89,13 @@ public class CodecConfigurerTests {
@Test
public void defaultWriters() {
List<HttpMessageWriter<?>> writers = this.configurer.getWriters();
assertEquals(9, writers.size());
assertEquals(10, writers.size());
assertEquals(ByteArrayEncoder.class, getNextEncoder(writers).getClass());
assertEquals(ByteBufferEncoder.class, getNextEncoder(writers).getClass());
assertEquals(DataBufferEncoder.class, getNextEncoder(writers).getClass());
assertEquals(ResourceHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass());
assertStringEncoder(getNextEncoder(writers), true);
assertEquals(ProtobufHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass());
assertEquals(Jackson2JsonEncoder.class, getNextEncoder(writers).getClass());
assertEquals(Jackson2SmileEncoder.class, getNextEncoder(writers).getClass());
assertEquals(Jaxb2XmlEncoder.class, getNextEncoder(writers).getClass());
@@ -117,12 +124,13 @@ public class CodecConfigurerTests {
List<HttpMessageReader<?>> readers = this.configurer.getReaders();
assertEquals(14, readers.size());
assertEquals(15, readers.size());
assertEquals(ByteArrayDecoder.class, getNextDecoder(readers).getClass());
assertEquals(ByteBufferDecoder.class, getNextDecoder(readers).getClass());
assertEquals(DataBufferDecoder.class, getNextDecoder(readers).getClass());
assertEquals(ResourceDecoder.class, getNextDecoder(readers).getClass());
assertEquals(StringDecoder.class, getNextDecoder(readers).getClass());
assertEquals(ProtobufDecoder.class, getNextDecoder(readers).getClass());
assertEquals(FormHttpMessageReader.class, readers.get(this.index.getAndIncrement()).getClass());
assertSame(customDecoder1, getNextDecoder(readers));
assertSame(customReader1, readers.get(this.index.getAndIncrement()));
@@ -156,12 +164,13 @@ public class CodecConfigurerTests {
List<HttpMessageWriter<?>> writers = this.configurer.getWriters();
assertEquals(13, writers.size());
assertEquals(14, writers.size());
assertEquals(ByteArrayEncoder.class, getNextEncoder(writers).getClass());
assertEquals(ByteBufferEncoder.class, getNextEncoder(writers).getClass());
assertEquals(DataBufferEncoder.class, getNextEncoder(writers).getClass());
assertEquals(ResourceHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass());
assertEquals(CharSequenceEncoder.class, getNextEncoder(writers).getClass());
assertEquals(ProtobufHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass());
assertSame(customEncoder1, getNextEncoder(writers));
assertSame(customWriter1, writers.get(this.index.getAndIncrement()));
assertEquals(Jackson2JsonEncoder.class, getNextEncoder(writers).getClass());
@@ -247,6 +256,19 @@ public class CodecConfigurerTests {
.filter(e -> e == decoder).orElse(null));
}
@Test
public void protobufDecoderOverride() {
ProtobufDecoder decoder = new ProtobufDecoder(ExtensionRegistry.newInstance());
this.configurer.defaultCodecs().protobufDecoder(decoder);
assertSame(decoder, this.configurer.getReaders().stream()
.filter(writer -> writer instanceof DecoderHttpMessageReader)
.map(writer -> ((DecoderHttpMessageReader<?>) writer).getDecoder())
.filter(e -> ProtobufDecoder.class.equals(e.getClass()))
.findFirst()
.filter(e -> e == decoder).orElse(null));
}
@Test
public void jackson2EncoderOverride() {
Jackson2JsonEncoder encoder = new Jackson2JsonEncoder();
@@ -260,6 +282,20 @@ public class CodecConfigurerTests {
.filter(e -> e == encoder).orElse(null));
}
@Test
public void protobufWriterOverride() {
ProtobufEncoder encoder = new ProtobufEncoder();
ProtobufHttpMessageWriter messageWriter = new ProtobufHttpMessageWriter(encoder);
this.configurer.defaultCodecs().protobufWriter(messageWriter);
assertSame(encoder, this.configurer.getWriters().stream()
.filter(writer -> writer instanceof EncoderHttpMessageWriter)
.map(writer -> ((EncoderHttpMessageWriter<?>) writer).getEncoder())
.filter(e -> ProtobufEncoder.class.equals(e.getClass()))
.findFirst()
.filter(e -> e == encoder).orElse(null));
}
private Decoder<?> getNextDecoder(List<HttpMessageReader<?>> readers) {
HttpMessageReader<?> reader = readers.get(this.index.getAndIncrement());

View File

@@ -54,6 +54,8 @@ import org.springframework.http.codec.json.Jackson2SmileDecoder;
import org.springframework.http.codec.json.Jackson2SmileEncoder;
import org.springframework.http.codec.multipart.MultipartHttpMessageReader;
import org.springframework.http.codec.multipart.SynchronossPartHttpMessageReader;
import org.springframework.http.codec.protobuf.ProtobufDecoder;
import org.springframework.http.codec.protobuf.ProtobufHttpMessageWriter;
import org.springframework.http.codec.xml.Jaxb2XmlDecoder;
import org.springframework.http.codec.xml.Jaxb2XmlEncoder;
import org.springframework.util.MimeTypeUtils;
@@ -76,12 +78,13 @@ public class ServerCodecConfigurerTests {
@Test
public void defaultReaders() {
List<HttpMessageReader<?>> readers = this.configurer.getReaders();
assertEquals(12, readers.size());
assertEquals(13, readers.size());
assertEquals(ByteArrayDecoder.class, getNextDecoder(readers).getClass());
assertEquals(ByteBufferDecoder.class, getNextDecoder(readers).getClass());
assertEquals(DataBufferDecoder.class, getNextDecoder(readers).getClass());
assertEquals(ResourceDecoder.class, getNextDecoder(readers).getClass());
assertStringDecoder(getNextDecoder(readers), true);
assertEquals(ProtobufDecoder.class, getNextDecoder(readers).getClass());
assertEquals(FormHttpMessageReader.class, readers.get(this.index.getAndIncrement()).getClass());
assertEquals(SynchronossPartHttpMessageReader.class, readers.get(this.index.getAndIncrement()).getClass());
assertEquals(MultipartHttpMessageReader.class, readers.get(this.index.getAndIncrement()).getClass());
@@ -94,12 +97,13 @@ public class ServerCodecConfigurerTests {
@Test
public void defaultWriters() {
List<HttpMessageWriter<?>> writers = this.configurer.getWriters();
assertEquals(10, writers.size());
assertEquals(11, writers.size());
assertEquals(ByteArrayEncoder.class, getNextEncoder(writers).getClass());
assertEquals(ByteBufferEncoder.class, getNextEncoder(writers).getClass());
assertEquals(DataBufferEncoder.class, getNextEncoder(writers).getClass());
assertEquals(ResourceHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass());
assertStringEncoder(getNextEncoder(writers), true);
assertEquals(ProtobufHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass());
assertEquals(Jackson2JsonEncoder.class, getNextEncoder(writers).getClass());
assertEquals(Jackson2SmileEncoder.class, getNextEncoder(writers).getClass());
assertEquals(Jaxb2XmlEncoder.class, getNextEncoder(writers).getClass());