diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java index c15ca7fdbb..4484c4eaa9 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java @@ -163,6 +163,14 @@ public abstract class AbstractListenerReadPublisher implements Publisher { */ protected abstract void readingPaused(); + /** + * Invoked after an I/O read error from the underlying server or after a + * cancellation signal from the downstream consumer to allow sub-classes + * to discard any current cached data they might have. + * @since 5.1.2 + */ + protected abstract void discardData(); + // Private methods for use in State... @@ -416,7 +424,10 @@ public abstract class AbstractListenerReadPublisher implements Publisher { } void cancel(AbstractListenerReadPublisher publisher) { - if (!publisher.changeState(this, COMPLETED)) { + if (publisher.changeState(this, COMPLETED)) { + publisher.discardData(); + } + else { publisher.state.get().cancel(publisher); } } @@ -439,6 +450,7 @@ public abstract class AbstractListenerReadPublisher implements Publisher { void onError(AbstractListenerReadPublisher publisher, Throwable t) { if (publisher.changeState(this, COMPLETED)) { + publisher.discardData(); Subscriber s = publisher.subscriber; if (s != null) { s.onError(t); diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerWriteProcessor.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerWriteProcessor.java index 9383ab8912..f6cd1877d8 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerWriteProcessor.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerWriteProcessor.java @@ -158,6 +158,9 @@ public abstract class AbstractListenerWriteProcessor implements Processor subscriber) { + // Technically, cancellation from the result subscriber should be propagated + // to the upstream subscription. In practice, HttpHandler server adapters + // don't have a reason to cancel the result subscription. this.resultPublisher.subscribe(subscriber); } @@ -176,8 +179,14 @@ public abstract class AbstractListenerWriteProcessor implements Processor implements Processor implements Processor implements Processor implements Processor void onNext(AbstractListenerWriteProcessor processor, T data) { - throw new IllegalStateException(toString()); + processor.discardData(data); + processor.cancel(); + processor.onError(new IllegalStateException("Illegal onNext without demand")); } public void onError(AbstractListenerWriteProcessor processor, Throwable ex) { if (processor.changeState(this, COMPLETED)) { + processor.discardCurrentData(); processor.writingComplete(); processor.resultPublisher.publishError(ex); } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java index 884f2ccd51..52b7de1690 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java @@ -29,9 +29,7 @@ import reactor.netty.Connection; import reactor.netty.http.server.HttpServerRequest; import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.io.buffer.NettyDataBufferFactory; -import org.springframework.core.io.buffer.PooledDataBuffer; import org.springframework.http.HttpCookie; import org.springframework.http.HttpHeaders; import org.springframework.lang.Nullable; @@ -165,8 +163,7 @@ class ReactorServerHttpRequest extends AbstractServerHttpRequest { @Override public Flux getBody() { - Flux body = this.request.receive().retain().map(this.bufferFactory::wrap); - return body.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release); + return this.request.receive().retain().map(this.bufferFactory::wrap); } @SuppressWarnings("unchecked") diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java index e2d6ec2634..aebbc13f67 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java @@ -302,6 +302,11 @@ class ServletServerHttpRequest extends AbstractServerHttpRequest { // no-op } + @Override + protected void discardData() { + // Nothing to discard since we pass data buffers on immediately.. + } + private class RequestBodyPublisherReadListener implements ReadListener { diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java index 9d11809031..eae0792339 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java @@ -334,6 +334,7 @@ class ServletServerHttpResponse extends AbstractListenerServerHttpResponse { boolean ready = ServletServerHttpResponse.this.isWritePossible(); int remaining = dataBuffer.readableByteCount(); if (ready && remaining > 0) { + // In case of IOException, onError handling should call discardData(DataBuffer).. int written = writeToOutputStream(dataBuffer); if (logger.isTraceEnabled()) { logger.trace(getLogPrefix() + "Wrote " + written + " of " + remaining + " bytes"); @@ -359,6 +360,11 @@ class ServletServerHttpResponse extends AbstractListenerServerHttpResponse { protected void writingComplete() { bodyProcessor = null; } + + @Override + protected void discardData(DataBuffer dataBuffer) { + DataBufferUtils.release(dataBuffer); + } } } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java index 6c68f5a528..05e641a3ce 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java @@ -116,8 +116,7 @@ class UndertowServerHttpRequest extends AbstractServerHttpRequest { @Override public Flux getBody() { - return Flux.from(this.body) - .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release); + return Flux.from(this.body); } @SuppressWarnings("unchecked") @@ -201,6 +200,10 @@ class UndertowServerHttpRequest extends AbstractServerHttpRequest { } } + @Override + protected void discardData() { + // Nothing to discard since we pass data buffers on immediately.. + } } private static class UndertowDataBuffer implements PooledDataBuffer { diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java index a9533379cd..c0ca6d6c48 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java @@ -181,6 +181,7 @@ class UndertowServerHttpResponse extends AbstractListenerServerHttpResponse impl // Track write listener calls from here on.. this.writePossible = false; + // In case of IOException, onError handling should call discardData(DataBuffer).. int total = buffer.remaining(); int written = writeByteBuffer(buffer); @@ -235,6 +236,11 @@ class UndertowServerHttpResponse extends AbstractListenerServerHttpResponse impl cancel(); onError(ex); } + + @Override + protected void discardData(DataBuffer dataBuffer) { + DataBufferUtils.release(dataBuffer); + } } diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerReadPublisherTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerReadPublisherTests.java index 7a48c6b311..6c054aaded 100644 --- a/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerReadPublisherTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerReadPublisherTests.java @@ -16,58 +16,95 @@ package org.springframework.http.server.reactive; -import java.io.IOException; - +import org.junit.Before; import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import org.springframework.core.io.buffer.DataBuffer; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.isA; -import static org.mockito.Mockito.mock; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; /** - * Unit tests for {@link AbstractListenerReadPublisher} + * Unit tests for {@link AbstractListenerReadPublisher}. * * @author Violeta Georgieva - * @since 5.0 + * @author Rossen Stoyanchev */ public class ListenerReadPublisherTests { - @Test - @SuppressWarnings("unchecked") - public void testReceiveTwoRequestCallsWhenOnSubscribe() { - Subscriber subscriber = mock(Subscriber.class); - doAnswer(new SubscriptionAnswer()).when(subscriber).onSubscribe(isA(Subscription.class)); + private final TestListenerReadPublisher publisher = new TestListenerReadPublisher(); - TestListenerReadPublisher publisher = new TestListenerReadPublisher(); - publisher.subscribe(subscriber); - publisher.onDataAvailable(); + private final TestSubscriber subscriber = new TestSubscriber(); - assertTrue(publisher.getReadCalls() == 2); + + @Before + public void setup() { + this.publisher.subscribe(this.subscriber); } - private static final class TestListenerReadPublisher extends AbstractListenerReadPublisher { + + @Test + public void twoReads() { + + this.subscriber.getSubscription().request(2); + this.publisher.onDataAvailable(); + + assertEquals(2, this.publisher.getReadCalls()); + } + + @Test // SPR-17410 + public void discardDataOnError() { + + this.subscriber.getSubscription().request(2); + this.publisher.onDataAvailable(); + this.publisher.onError(new IllegalStateException()); + + assertEquals(2, this.publisher.getReadCalls()); + assertEquals(1, this.publisher.getDiscardCalls()); + } + + @Test // SPR-17410 + public void discardDataOnCancel() { + + this.subscriber.getSubscription().request(2); + this.subscriber.setCancelOnNext(true); + this.publisher.onDataAvailable(); + + assertEquals(1, this.publisher.getReadCalls()); + assertEquals(1, this.publisher.getDiscardCalls()); + } + + + private static final class TestListenerReadPublisher extends AbstractListenerReadPublisher { private int readCalls = 0; + private int discardCalls = 0; + + public TestListenerReadPublisher() { super(""); } + + public int getReadCalls() { + return this.readCalls; + } + + public int getDiscardCalls() { + return this.discardCalls; + } + @Override protected void checkOnDataAvailable() { // no-op } @Override - protected DataBuffer read() throws IOException { - readCalls++; + protected DataBuffer read() { + this.readCalls++; return mock(DataBuffer.class); } @@ -76,22 +113,48 @@ public class ListenerReadPublisherTests { // No-op } - public int getReadCalls() { - return this.readCalls; + @Override + protected void discardData() { + this.discardCalls++; } - } - private static final class SubscriptionAnswer implements Answer { - @Override - public Subscription answer(InvocationOnMock invocation) throws Throwable { - Subscription arg = (Subscription) invocation.getArguments()[0]; - arg.request(1); - arg.request(1); - return arg; + private static final class TestSubscriber implements Subscriber { + + private Subscription subscription; + + private boolean cancelOnNext; + + + public Subscription getSubscription() { + return this.subscription; } + public void setCancelOnNext(boolean cancelOnNext) { + this.cancelOnNext = cancelOnNext; + } + + + @Override + public void onSubscribe(Subscription subscription) { + this.subscription = subscription; + } + + @Override + public void onNext(DataBuffer dataBuffer) { + if (this.cancelOnNext) { + this.subscription.cancel(); + } + } + + @Override + public void onError(Throwable t) { + } + + @Override + public void onComplete() { + } } } diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerWriteProcessorTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerWriteProcessorTests.java new file mode 100644 index 0000000000..80348355bf --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerWriteProcessorTests.java @@ -0,0 +1,206 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.http.server.reactive; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import org.springframework.core.io.buffer.DataBuffer; + +import static junit.framework.TestCase.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link AbstractListenerWriteProcessor}. + * + * @author Rossen Stoyanchev + */ +public class ListenerWriteProcessorTests { + + private final TestListenerWriteProcessor processor = new TestListenerWriteProcessor(); + + private final TestResultSubscriber resultSubscriber = new TestResultSubscriber(); + + private final TestSubscription subscription = new TestSubscription(); + + + @Before + public void setup() { + this.processor.subscribe(this.resultSubscriber); + this.processor.onSubscribe(this.subscription); + assertEquals(1, subscription.getDemand()); + } + + + @Test // SPR-17410 + public void writePublisherError() { + + // Turn off writing so next item will be cached + this.processor.setWritePossible(false); + DataBuffer buffer = mock(DataBuffer.class); + this.processor.onNext(buffer); + + // Send error while item cached + this.processor.onError(new IllegalStateException()); + + assertNotNull("Error should flow to result publisher", this.resultSubscriber.getError()); + assertEquals(1, this.processor.getDiscardedBuffers().size()); + assertSame(buffer, this.processor.getDiscardedBuffers().get(0)); + } + + @Test // SPR-17410 + public void ioExceptionDuringWrite() { + + // Fail on next write + this.processor.setWritePossible(true); + this.processor.setFailOnWrite(true); + + // Write + DataBuffer buffer = mock(DataBuffer.class); + this.processor.onNext(buffer); + + assertNotNull("Error should flow to result publisher", this.resultSubscriber.getError()); + assertEquals(1, this.processor.getDiscardedBuffers().size()); + assertSame(buffer, this.processor.getDiscardedBuffers().get(0)); + } + + @Test // SPR-17410 + public void onNextWithoutDemand() { + + // Disable writing: next item will be cached.. + this.processor.setWritePossible(false); + DataBuffer buffer1 = mock(DataBuffer.class); + this.processor.onNext(buffer1); + + // Send more data illegally + DataBuffer buffer2 = mock(DataBuffer.class); + this.processor.onNext(buffer2); + + assertNotNull("Error should flow to result publisher", this.resultSubscriber.getError()); + assertEquals(2, this.processor.getDiscardedBuffers().size()); + assertSame(buffer2, this.processor.getDiscardedBuffers().get(0)); + assertSame(buffer1, this.processor.getDiscardedBuffers().get(1)); + } + + + private static final class TestListenerWriteProcessor extends AbstractListenerWriteProcessor { + + private final List discardedBuffers = new ArrayList<>(); + + private boolean writePossible; + + private boolean failOnWrite; + + + public List getDiscardedBuffers() { + return this.discardedBuffers; + } + + public void setWritePossible(boolean writePossible) { + this.writePossible = writePossible; + } + + public void setFailOnWrite(boolean failOnWrite) { + this.failOnWrite = failOnWrite; + } + + + @Override + protected boolean isDataEmpty(DataBuffer dataBuffer) { + return false; + } + + @Override + protected boolean isWritePossible() { + return this.writePossible; + } + + @Override + protected boolean write(DataBuffer dataBuffer) throws IOException { + if (this.failOnWrite) { + throw new IOException("write failed"); + } + return true; + } + + @Override + protected void writingFailed(Throwable ex) { + cancel(); + onError(ex); + } + + @Override + protected void discardData(DataBuffer dataBuffer) { + this.discardedBuffers.add(dataBuffer); + } + } + + + private static final class TestSubscription implements Subscription { + + private long demand; + + + public long getDemand() { + return this.demand; + } + + + @Override + public void request(long n) { + this.demand = (n == Long.MAX_VALUE ? n : this.demand + n); + } + + @Override + public void cancel() { + } + } + + private static final class TestResultSubscriber implements Subscriber { + + private Throwable error; + + + public Throwable getError() { + return this.error; + } + + + @Override + public void onSubscribe(Subscription subscription) { + } + + @Override + public void onNext(Void aVoid) { + } + + @Override + public void onError(Throwable ex) { + this.error = ex; + } + + @Override + public void onComplete() { + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java index ee324460fc..1f1ab1d83c 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java @@ -260,11 +260,23 @@ public abstract class AbstractListenerWebSocketSession extends AbstractWebSoc rsReadLogger.trace(getLogPrefix() + "Received " + message); } if (!this.pendingMessages.offer(message)) { + discardData(); throw new IllegalStateException( "Too many messages. Please ensure WebSocketSession.receive() is subscribed to."); } onDataAvailable(); } + + @Override + protected void discardData() { + while (true) { + WebSocketMessage message = (WebSocketMessage) this.pendingMessages.poll(); + if (message == null) { + return; + } + message.release(); + } + } } @@ -289,6 +301,7 @@ public abstract class AbstractListenerWebSocketSession extends AbstractWebSoc else if (rsWriteLogger.isTraceEnabled()) { rsWriteLogger.trace(getLogPrefix() + "Sending " + message); } + // In case of IOException, onError handling should call discardData(WebSocketMessage).. return sendMessage(message); } @@ -313,6 +326,11 @@ public abstract class AbstractListenerWebSocketSession extends AbstractWebSoc } this.isReady = ready; } + + @Override + protected void discardData(WebSocketMessage message) { + message.release(); + } } }