diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java index 6c5449a644..0f242f781e 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java @@ -19,8 +19,6 @@ package org.springframework.http.server.reactive; import java.io.IOException; import java.util.Map; import javax.servlet.AsyncContext; -import javax.servlet.AsyncEvent; -import javax.servlet.AsyncListener; import javax.servlet.Servlet; import javax.servlet.ServletConfig; import javax.servlet.ServletRequest; @@ -39,7 +37,7 @@ import org.springframework.util.Assert; /** * Adapt {@link HttpHandler} to an {@link HttpServlet} using Servlet Async - * support and Servlet 3.1 Non-blocking I/O. + * support and Servlet 3.1 non-blocking I/O. * * @author Arjen Poutsma * @author Rossen Stoyanchev @@ -47,8 +45,7 @@ import org.springframework.util.Assert; */ @WebServlet(asyncSupported = true) @SuppressWarnings("serial") -public class ServletHttpHandlerAdapter extends HttpHandlerAdapterSupport - implements Servlet { +public class ServletHttpHandlerAdapter extends HttpHandlerAdapterSupport implements Servlet { private static final int DEFAULT_BUFFER_SIZE = 8192; @@ -78,30 +75,37 @@ public class ServletHttpHandlerAdapter extends HttpHandlerAdapterSupport return this.dataBufferFactory; } + /** + * Set the size of the input buffer used for reading in bytes. + *

By default this is set to 8192. + */ public void setBufferSize(int bufferSize) { Assert.isTrue(bufferSize > 0); this.bufferSize = bufferSize; } + /** + * Return the configured input buffer size. + */ public int getBufferSize() { return this.bufferSize; } + @Override - public void service(ServletRequest servletRequest, ServletResponse servletResponse) throws IOException { + public void service(ServletRequest request, ServletResponse response) throws IOException { // Start async before Read/WriteListener registration - AsyncContext asyncContext = servletRequest.startAsync(); + AsyncContext asyncContext = request.startAsync(); - ServletServerHttpRequest request = new ServletServerHttpRequest( - ((HttpServletRequest) servletRequest), getDataBufferFactory(), getBufferSize()); - ServletServerHttpResponse response = new ServletServerHttpResponse( - ((HttpServletResponse) servletResponse), getDataBufferFactory(), getBufferSize()); + ServerHttpRequest httpRequest = new ServletServerHttpRequest( + ((HttpServletRequest) request), asyncContext, getDataBufferFactory(), getBufferSize()); - asyncContext.addListener(new EventHandlingAsyncListener(request, response)); + ServerHttpResponse httpResponse = new ServletServerHttpResponse( + ((HttpServletResponse) response), asyncContext, getDataBufferFactory(), getBufferSize()); - HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(asyncContext); - getHttpHandler().handle(request, response).subscribe(resultSubscriber); + HandlerResultSubscriber subscriber = new HandlerResultSubscriber(asyncContext); + getHttpHandler().handle(httpRequest, httpResponse).subscribe(subscriber); } // Other Servlet methods... @@ -129,6 +133,7 @@ public class ServletHttpHandlerAdapter extends HttpHandlerAdapterSupport private final AsyncContext asyncContext; + public HandlerResultSubscriber(AsyncContext asyncContext) { this.asyncContext = asyncContext; } @@ -158,48 +163,4 @@ public class ServletHttpHandlerAdapter extends HttpHandlerAdapterSupport } } - - private static final class EventHandlingAsyncListener implements AsyncListener { - - private final ServletServerHttpRequest request; - - private final ServletServerHttpResponse response; - - - public EventHandlingAsyncListener(ServletServerHttpRequest request, - ServletServerHttpResponse response) { - - this.request = request; - this.response = response; - } - - - @Override - public void onTimeout(AsyncEvent event) { - Throwable ex = event.getThrowable(); - if (ex == null) { - ex = new IllegalStateException("Async operation timeout."); - } - this.request.handleAsyncListenerError(ex); - this.response.handleAsyncListenerError(ex); - } - - @Override - public void onError(AsyncEvent event) { - this.request.handleAsyncListenerError(event.getThrowable()); - this.response.handleAsyncListenerError(event.getThrowable()); - } - - @Override - public void onStartAsync(AsyncEvent event) { - // no op - } - - @Override - public void onComplete(AsyncEvent event) { - this.request.handleAsyncListenerComplete(); - this.response.handleAsyncListenerComplete(); - } - } - } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java index 08d204fed1..33c6d5bd30 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java @@ -22,6 +22,9 @@ import java.net.URISyntaxException; import java.nio.charset.Charset; import java.util.Enumeration; import java.util.Map; +import javax.servlet.AsyncContext; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; import javax.servlet.ReadListener; import javax.servlet.ServletInputStream; import javax.servlet.http.Cookie; @@ -51,18 +54,18 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { private final HttpServletRequest request; - private final DataBufferFactory dataBufferFactory; + private final DataBufferFactory bufferFactory; private final int bufferSize; - private final Object bodyPublisherMonitor = new Object(); + private final Object bodyPublisherLock = new Object(); private volatile RequestBodyPublisher bodyPublisher; private final Object cookieLock = new Object(); - public ServletServerHttpRequest(HttpServletRequest request, + public ServletServerHttpRequest(HttpServletRequest request, AsyncContext asyncContext, DataBufferFactory bufferFactory, int bufferSize) { super(initUri(request), initHeaders(request)); @@ -71,8 +74,10 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { Assert.isTrue(bufferSize > 0, "'bufferSize' must be higher than 0"); this.request = request; - this.dataBufferFactory = bufferFactory; + this.bufferFactory = bufferFactory; this.bufferSize = bufferSize; + + asyncContext.addListener(new RequestAsyncListener()); } private static URI initUri(HttpServletRequest request) { @@ -164,65 +169,78 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { @Override public Flux getBody() { try { - RequestBodyPublisher bodyPublisher = this.bodyPublisher; - if (bodyPublisher == null) { - synchronized (this.bodyPublisherMonitor) { - bodyPublisher = this.bodyPublisher; - if (bodyPublisher == null) { - bodyPublisher = createBodyPublisher(); - this.bodyPublisher = bodyPublisher; + RequestBodyPublisher publisher = this.bodyPublisher; + if (publisher == null) { + synchronized (this.bodyPublisherLock) { + publisher = this.bodyPublisher; + if (publisher == null) { + ServletInputStream inputStream = this.request.getInputStream(); + publisher = new RequestBodyPublisher(inputStream, this.bufferFactory, this.bufferSize); + publisher.registerReadListener(); + this.bodyPublisher = publisher; } } } - return Flux.from(bodyPublisher); + return Flux.from(publisher); } catch (IOException ex) { return Flux.error(ex); } } - /** Handle a timeout/error callback from the Servlet container */ - void handleAsyncListenerError(Throwable ex) { - if (this.bodyPublisher != null) { - this.bodyPublisher.onError(ex); + + private final class RequestAsyncListener implements AsyncListener { + + @Override + public void onStartAsync(AsyncEvent event) {} + + @Override + public void onTimeout(AsyncEvent event) { + Throwable ex = event.getThrowable(); + ex = ex != null ? ex : new IllegalStateException("Async operation timeout."); + handleError(ex); + } + + @Override + public void onError(AsyncEvent event) { + handleError(event.getThrowable()); + } + + private void handleError(Throwable ex) { + if (bodyPublisher != null) { + bodyPublisher.onError(ex); + } + } + + @Override + public void onComplete(AsyncEvent event) { + if (bodyPublisher != null) { + bodyPublisher.onAllDataRead(); + } } } - /** Handle a complete callback from the Servlet container */ - void handleAsyncListenerComplete() { - if (this.bodyPublisher != null) { - this.bodyPublisher.onAllDataRead(); - } - } - - private RequestBodyPublisher createBodyPublisher() throws IOException { - RequestBodyPublisher bodyPublisher = new RequestBodyPublisher( - this.request.getInputStream(), this.dataBufferFactory, this.bufferSize); - bodyPublisher.registerListener(); - return bodyPublisher; - } - - private static class RequestBodyPublisher extends AbstractListenerReadPublisher { - private final RequestBodyPublisher.RequestBodyReadListener readListener = - new RequestBodyPublisher.RequestBodyReadListener(); + private final RequestBodyPublisherReadListener readListener = + new RequestBodyPublisherReadListener(); private final ServletInputStream inputStream; - private final DataBufferFactory dataBufferFactory; + private final DataBufferFactory bufferFactory; private final byte[] buffer; + public RequestBodyPublisher(ServletInputStream inputStream, - DataBufferFactory dataBufferFactory, int bufferSize) { + DataBufferFactory bufferFactory, int bufferSize) { this.inputStream = inputStream; - this.dataBufferFactory = dataBufferFactory; + this.bufferFactory = bufferFactory; this.buffer = new byte[bufferSize]; } - public void registerListener() throws IOException { + public void registerReadListener() throws IOException { this.inputStream.setReadListener(this.readListener); } @@ -242,7 +260,7 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { } if (read > 0) { - DataBuffer dataBuffer = this.dataBufferFactory.allocateBuffer(read); + DataBuffer dataBuffer = this.bufferFactory.allocateBuffer(read); dataBuffer.write(this.buffer, 0, read); return dataBuffer; } @@ -251,7 +269,7 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { } - private class RequestBodyReadListener implements ReadListener { + private class RequestBodyPublisherReadListener implements ReadListener { @Override public void onDataAvailable() throws IOException { diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java index 2d403add8f..ed4ab11bc9 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java @@ -22,6 +22,9 @@ import java.io.UncheckedIOException; import java.nio.charset.Charset; import java.util.List; import java.util.Map; +import javax.servlet.AsyncContext; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; import javax.servlet.ServletOutputStream; import javax.servlet.WriteListener; import javax.servlet.http.Cookie; @@ -58,20 +61,22 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons private volatile ResponseBodyFlushProcessor bodyFlushProcessor; - public ServletServerHttpResponse(HttpServletResponse response, - DataBufferFactory dataBufferFactory, int bufferSize) throws IOException { + public ServletServerHttpResponse(HttpServletResponse response, AsyncContext asyncContext, + DataBufferFactory bufferFactory, int bufferSize) throws IOException { - super(dataBufferFactory); + super(bufferFactory); Assert.notNull(response, "HttpServletResponse must not be null"); - Assert.notNull(dataBufferFactory, "DataBufferFactory must not be null"); - Assert.isTrue(bufferSize > 0, "Buffer size must be higher than 0"); + Assert.notNull(bufferFactory, "DataBufferFactory must not be null"); + Assert.isTrue(bufferSize > 0, "'bufferSize' must be greater than 0"); this.response = response; this.bufferSize = bufferSize; + asyncContext.addListener(new ResponseAsyncListener()); + // Tomcat expects WriteListener registration on initial thread - registerListener(); + registerWriteListener(); } @@ -129,7 +134,7 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons return processor; } - private void registerListener() { + private void registerWriteListener() { try { outputStream().setWriteListener(this.writeListener); } @@ -159,31 +164,48 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons } } - /** Handle a timeout/error callback from the Servlet container */ - void handleAsyncListenerError(Throwable ex) { - if (this.bodyFlushProcessor != null) { - this.bodyFlushProcessor.cancel(); - this.bodyFlushProcessor.onError(ex); + + private final class ResponseAsyncListener implements AsyncListener { + + @Override + public void onStartAsync(AsyncEvent event) {} + + @Override + public void onTimeout(AsyncEvent event) { + Throwable ex = event.getThrowable(); + ex = (ex != null ? ex : new IllegalStateException("Async operation timeout.")); + handleError(ex); } - if (this.bodyProcessor != null) { - this.bodyProcessor.cancel(); - this.bodyProcessor.onError(ex); + + @Override + public void onError(AsyncEvent event) { + handleError(event.getThrowable()); + } + + void handleError(Throwable ex) { + if (bodyFlushProcessor != null) { + bodyFlushProcessor.cancel(); + bodyFlushProcessor.onError(ex); + } + if (bodyProcessor != null) { + bodyProcessor.cancel(); + bodyProcessor.onError(ex); + } + } + + @Override + public void onComplete(AsyncEvent event) { + if (bodyFlushProcessor != null) { + bodyFlushProcessor.cancel(); + bodyFlushProcessor.onComplete(); + } + if (bodyProcessor != null) { + bodyProcessor.cancel(); + bodyProcessor.onComplete(); + } } } - /** Handle a complete callback from the Servlet container */ - void handleAsyncListenerComplete() { - if (this.bodyFlushProcessor != null) { - this.bodyFlushProcessor.cancel(); - this.bodyFlushProcessor.onComplete(); - } - if (this.bodyProcessor != null) { - this.bodyProcessor.cancel(); - this.bodyProcessor.onComplete(); - } - } - - private class ResponseBodyProcessor extends AbstractListenerWriteProcessor { private final ServletOutputStream outputStream; @@ -272,7 +294,6 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons } } - private class ResponseBodyFlushProcessor extends AbstractListenerFlushProcessor { @Override diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java index c48115abbc..4b7d2aab27 100644 --- a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java @@ -18,12 +18,15 @@ package org.springframework.http.server.reactive; import java.util.Arrays; import java.util.Collections; +import javax.servlet.AsyncContext; import javax.servlet.http.HttpServletRequest; import org.junit.Test; import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.mock.web.test.MockAsyncContext; import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; import org.springframework.util.MultiValueMap; import static org.junit.Assert.assertEquals; @@ -72,9 +75,9 @@ public class ServerHttpRequestTests { } private ServerHttpRequest createHttpRequest(String path) throws Exception { - HttpServletRequest servletRequest = new MockHttpServletRequest("GET", path); - return new ServletServerHttpRequest(servletRequest, - new DefaultDataBufferFactory(), 1024); + HttpServletRequest request = new MockHttpServletRequest("GET", path); + AsyncContext asyncContext = new MockAsyncContext(request, new MockHttpServletResponse()); + return new ServletServerHttpRequest(request, asyncContext, new DefaultDataBufferFactory(), 1024); } }