diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java
index 6c5449a644..0f242f781e 100644
--- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java
+++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java
@@ -19,8 +19,6 @@ package org.springframework.http.server.reactive;
import java.io.IOException;
import java.util.Map;
import javax.servlet.AsyncContext;
-import javax.servlet.AsyncEvent;
-import javax.servlet.AsyncListener;
import javax.servlet.Servlet;
import javax.servlet.ServletConfig;
import javax.servlet.ServletRequest;
@@ -39,7 +37,7 @@ import org.springframework.util.Assert;
/**
* Adapt {@link HttpHandler} to an {@link HttpServlet} using Servlet Async
- * support and Servlet 3.1 Non-blocking I/O.
+ * support and Servlet 3.1 non-blocking I/O.
*
* @author Arjen Poutsma
* @author Rossen Stoyanchev
@@ -47,8 +45,7 @@ import org.springframework.util.Assert;
*/
@WebServlet(asyncSupported = true)
@SuppressWarnings("serial")
-public class ServletHttpHandlerAdapter extends HttpHandlerAdapterSupport
- implements Servlet {
+public class ServletHttpHandlerAdapter extends HttpHandlerAdapterSupport implements Servlet {
private static final int DEFAULT_BUFFER_SIZE = 8192;
@@ -78,30 +75,37 @@ public class ServletHttpHandlerAdapter extends HttpHandlerAdapterSupport
return this.dataBufferFactory;
}
+ /**
+ * Set the size of the input buffer used for reading in bytes.
+ *
By default this is set to 8192.
+ */
public void setBufferSize(int bufferSize) {
Assert.isTrue(bufferSize > 0);
this.bufferSize = bufferSize;
}
+ /**
+ * Return the configured input buffer size.
+ */
public int getBufferSize() {
return this.bufferSize;
}
+
@Override
- public void service(ServletRequest servletRequest, ServletResponse servletResponse) throws IOException {
+ public void service(ServletRequest request, ServletResponse response) throws IOException {
// Start async before Read/WriteListener registration
- AsyncContext asyncContext = servletRequest.startAsync();
+ AsyncContext asyncContext = request.startAsync();
- ServletServerHttpRequest request = new ServletServerHttpRequest(
- ((HttpServletRequest) servletRequest), getDataBufferFactory(), getBufferSize());
- ServletServerHttpResponse response = new ServletServerHttpResponse(
- ((HttpServletResponse) servletResponse), getDataBufferFactory(), getBufferSize());
+ ServerHttpRequest httpRequest = new ServletServerHttpRequest(
+ ((HttpServletRequest) request), asyncContext, getDataBufferFactory(), getBufferSize());
- asyncContext.addListener(new EventHandlingAsyncListener(request, response));
+ ServerHttpResponse httpResponse = new ServletServerHttpResponse(
+ ((HttpServletResponse) response), asyncContext, getDataBufferFactory(), getBufferSize());
- HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(asyncContext);
- getHttpHandler().handle(request, response).subscribe(resultSubscriber);
+ HandlerResultSubscriber subscriber = new HandlerResultSubscriber(asyncContext);
+ getHttpHandler().handle(httpRequest, httpResponse).subscribe(subscriber);
}
// Other Servlet methods...
@@ -129,6 +133,7 @@ public class ServletHttpHandlerAdapter extends HttpHandlerAdapterSupport
private final AsyncContext asyncContext;
+
public HandlerResultSubscriber(AsyncContext asyncContext) {
this.asyncContext = asyncContext;
}
@@ -158,48 +163,4 @@ public class ServletHttpHandlerAdapter extends HttpHandlerAdapterSupport
}
}
-
- private static final class EventHandlingAsyncListener implements AsyncListener {
-
- private final ServletServerHttpRequest request;
-
- private final ServletServerHttpResponse response;
-
-
- public EventHandlingAsyncListener(ServletServerHttpRequest request,
- ServletServerHttpResponse response) {
-
- this.request = request;
- this.response = response;
- }
-
-
- @Override
- public void onTimeout(AsyncEvent event) {
- Throwable ex = event.getThrowable();
- if (ex == null) {
- ex = new IllegalStateException("Async operation timeout.");
- }
- this.request.handleAsyncListenerError(ex);
- this.response.handleAsyncListenerError(ex);
- }
-
- @Override
- public void onError(AsyncEvent event) {
- this.request.handleAsyncListenerError(event.getThrowable());
- this.response.handleAsyncListenerError(event.getThrowable());
- }
-
- @Override
- public void onStartAsync(AsyncEvent event) {
- // no op
- }
-
- @Override
- public void onComplete(AsyncEvent event) {
- this.request.handleAsyncListenerComplete();
- this.response.handleAsyncListenerComplete();
- }
- }
-
}
diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java
index 08d204fed1..33c6d5bd30 100644
--- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java
+++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java
@@ -22,6 +22,9 @@ import java.net.URISyntaxException;
import java.nio.charset.Charset;
import java.util.Enumeration;
import java.util.Map;
+import javax.servlet.AsyncContext;
+import javax.servlet.AsyncEvent;
+import javax.servlet.AsyncListener;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.Cookie;
@@ -51,18 +54,18 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest {
private final HttpServletRequest request;
- private final DataBufferFactory dataBufferFactory;
+ private final DataBufferFactory bufferFactory;
private final int bufferSize;
- private final Object bodyPublisherMonitor = new Object();
+ private final Object bodyPublisherLock = new Object();
private volatile RequestBodyPublisher bodyPublisher;
private final Object cookieLock = new Object();
- public ServletServerHttpRequest(HttpServletRequest request,
+ public ServletServerHttpRequest(HttpServletRequest request, AsyncContext asyncContext,
DataBufferFactory bufferFactory, int bufferSize) {
super(initUri(request), initHeaders(request));
@@ -71,8 +74,10 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest {
Assert.isTrue(bufferSize > 0, "'bufferSize' must be higher than 0");
this.request = request;
- this.dataBufferFactory = bufferFactory;
+ this.bufferFactory = bufferFactory;
this.bufferSize = bufferSize;
+
+ asyncContext.addListener(new RequestAsyncListener());
}
private static URI initUri(HttpServletRequest request) {
@@ -164,65 +169,78 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest {
@Override
public Flux getBody() {
try {
- RequestBodyPublisher bodyPublisher = this.bodyPublisher;
- if (bodyPublisher == null) {
- synchronized (this.bodyPublisherMonitor) {
- bodyPublisher = this.bodyPublisher;
- if (bodyPublisher == null) {
- bodyPublisher = createBodyPublisher();
- this.bodyPublisher = bodyPublisher;
+ RequestBodyPublisher publisher = this.bodyPublisher;
+ if (publisher == null) {
+ synchronized (this.bodyPublisherLock) {
+ publisher = this.bodyPublisher;
+ if (publisher == null) {
+ ServletInputStream inputStream = this.request.getInputStream();
+ publisher = new RequestBodyPublisher(inputStream, this.bufferFactory, this.bufferSize);
+ publisher.registerReadListener();
+ this.bodyPublisher = publisher;
}
}
}
- return Flux.from(bodyPublisher);
+ return Flux.from(publisher);
}
catch (IOException ex) {
return Flux.error(ex);
}
}
- /** Handle a timeout/error callback from the Servlet container */
- void handleAsyncListenerError(Throwable ex) {
- if (this.bodyPublisher != null) {
- this.bodyPublisher.onError(ex);
+
+ private final class RequestAsyncListener implements AsyncListener {
+
+ @Override
+ public void onStartAsync(AsyncEvent event) {}
+
+ @Override
+ public void onTimeout(AsyncEvent event) {
+ Throwable ex = event.getThrowable();
+ ex = ex != null ? ex : new IllegalStateException("Async operation timeout.");
+ handleError(ex);
+ }
+
+ @Override
+ public void onError(AsyncEvent event) {
+ handleError(event.getThrowable());
+ }
+
+ private void handleError(Throwable ex) {
+ if (bodyPublisher != null) {
+ bodyPublisher.onError(ex);
+ }
+ }
+
+ @Override
+ public void onComplete(AsyncEvent event) {
+ if (bodyPublisher != null) {
+ bodyPublisher.onAllDataRead();
+ }
}
}
- /** Handle a complete callback from the Servlet container */
- void handleAsyncListenerComplete() {
- if (this.bodyPublisher != null) {
- this.bodyPublisher.onAllDataRead();
- }
- }
-
- private RequestBodyPublisher createBodyPublisher() throws IOException {
- RequestBodyPublisher bodyPublisher = new RequestBodyPublisher(
- this.request.getInputStream(), this.dataBufferFactory, this.bufferSize);
- bodyPublisher.registerListener();
- return bodyPublisher;
- }
-
-
private static class RequestBodyPublisher extends AbstractListenerReadPublisher {
- private final RequestBodyPublisher.RequestBodyReadListener readListener =
- new RequestBodyPublisher.RequestBodyReadListener();
+ private final RequestBodyPublisherReadListener readListener =
+ new RequestBodyPublisherReadListener();
private final ServletInputStream inputStream;
- private final DataBufferFactory dataBufferFactory;
+ private final DataBufferFactory bufferFactory;
private final byte[] buffer;
+
public RequestBodyPublisher(ServletInputStream inputStream,
- DataBufferFactory dataBufferFactory, int bufferSize) {
+ DataBufferFactory bufferFactory, int bufferSize) {
this.inputStream = inputStream;
- this.dataBufferFactory = dataBufferFactory;
+ this.bufferFactory = bufferFactory;
this.buffer = new byte[bufferSize];
}
- public void registerListener() throws IOException {
+ public void registerReadListener() throws IOException {
this.inputStream.setReadListener(this.readListener);
}
@@ -242,7 +260,7 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest {
}
if (read > 0) {
- DataBuffer dataBuffer = this.dataBufferFactory.allocateBuffer(read);
+ DataBuffer dataBuffer = this.bufferFactory.allocateBuffer(read);
dataBuffer.write(this.buffer, 0, read);
return dataBuffer;
}
@@ -251,7 +269,7 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest {
}
- private class RequestBodyReadListener implements ReadListener {
+ private class RequestBodyPublisherReadListener implements ReadListener {
@Override
public void onDataAvailable() throws IOException {
diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java
index 2d403add8f..ed4ab11bc9 100644
--- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java
+++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java
@@ -22,6 +22,9 @@ import java.io.UncheckedIOException;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Map;
+import javax.servlet.AsyncContext;
+import javax.servlet.AsyncEvent;
+import javax.servlet.AsyncListener;
import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import javax.servlet.http.Cookie;
@@ -58,20 +61,22 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons
private volatile ResponseBodyFlushProcessor bodyFlushProcessor;
- public ServletServerHttpResponse(HttpServletResponse response,
- DataBufferFactory dataBufferFactory, int bufferSize) throws IOException {
+ public ServletServerHttpResponse(HttpServletResponse response, AsyncContext asyncContext,
+ DataBufferFactory bufferFactory, int bufferSize) throws IOException {
- super(dataBufferFactory);
+ super(bufferFactory);
Assert.notNull(response, "HttpServletResponse must not be null");
- Assert.notNull(dataBufferFactory, "DataBufferFactory must not be null");
- Assert.isTrue(bufferSize > 0, "Buffer size must be higher than 0");
+ Assert.notNull(bufferFactory, "DataBufferFactory must not be null");
+ Assert.isTrue(bufferSize > 0, "'bufferSize' must be greater than 0");
this.response = response;
this.bufferSize = bufferSize;
+ asyncContext.addListener(new ResponseAsyncListener());
+
// Tomcat expects WriteListener registration on initial thread
- registerListener();
+ registerWriteListener();
}
@@ -129,7 +134,7 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons
return processor;
}
- private void registerListener() {
+ private void registerWriteListener() {
try {
outputStream().setWriteListener(this.writeListener);
}
@@ -159,31 +164,48 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons
}
}
- /** Handle a timeout/error callback from the Servlet container */
- void handleAsyncListenerError(Throwable ex) {
- if (this.bodyFlushProcessor != null) {
- this.bodyFlushProcessor.cancel();
- this.bodyFlushProcessor.onError(ex);
+
+ private final class ResponseAsyncListener implements AsyncListener {
+
+ @Override
+ public void onStartAsync(AsyncEvent event) {}
+
+ @Override
+ public void onTimeout(AsyncEvent event) {
+ Throwable ex = event.getThrowable();
+ ex = (ex != null ? ex : new IllegalStateException("Async operation timeout."));
+ handleError(ex);
}
- if (this.bodyProcessor != null) {
- this.bodyProcessor.cancel();
- this.bodyProcessor.onError(ex);
+
+ @Override
+ public void onError(AsyncEvent event) {
+ handleError(event.getThrowable());
+ }
+
+ void handleError(Throwable ex) {
+ if (bodyFlushProcessor != null) {
+ bodyFlushProcessor.cancel();
+ bodyFlushProcessor.onError(ex);
+ }
+ if (bodyProcessor != null) {
+ bodyProcessor.cancel();
+ bodyProcessor.onError(ex);
+ }
+ }
+
+ @Override
+ public void onComplete(AsyncEvent event) {
+ if (bodyFlushProcessor != null) {
+ bodyFlushProcessor.cancel();
+ bodyFlushProcessor.onComplete();
+ }
+ if (bodyProcessor != null) {
+ bodyProcessor.cancel();
+ bodyProcessor.onComplete();
+ }
}
}
- /** Handle a complete callback from the Servlet container */
- void handleAsyncListenerComplete() {
- if (this.bodyFlushProcessor != null) {
- this.bodyFlushProcessor.cancel();
- this.bodyFlushProcessor.onComplete();
- }
- if (this.bodyProcessor != null) {
- this.bodyProcessor.cancel();
- this.bodyProcessor.onComplete();
- }
- }
-
-
private class ResponseBodyProcessor extends AbstractListenerWriteProcessor {
private final ServletOutputStream outputStream;
@@ -272,7 +294,6 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons
}
}
-
private class ResponseBodyFlushProcessor extends AbstractListenerFlushProcessor {
@Override
diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java
index c48115abbc..4b7d2aab27 100644
--- a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java
+++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java
@@ -18,12 +18,15 @@ package org.springframework.http.server.reactive;
import java.util.Arrays;
import java.util.Collections;
+import javax.servlet.AsyncContext;
import javax.servlet.http.HttpServletRequest;
import org.junit.Test;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
+import org.springframework.mock.web.test.MockAsyncContext;
import org.springframework.mock.web.test.MockHttpServletRequest;
+import org.springframework.mock.web.test.MockHttpServletResponse;
import org.springframework.util.MultiValueMap;
import static org.junit.Assert.assertEquals;
@@ -72,9 +75,9 @@ public class ServerHttpRequestTests {
}
private ServerHttpRequest createHttpRequest(String path) throws Exception {
- HttpServletRequest servletRequest = new MockHttpServletRequest("GET", path);
- return new ServletServerHttpRequest(servletRequest,
- new DefaultDataBufferFactory(), 1024);
+ HttpServletRequest request = new MockHttpServletRequest("GET", path);
+ AsyncContext asyncContext = new MockAsyncContext(request, new MockHttpServletResponse());
+ return new ServletServerHttpRequest(request, asyncContext, new DefaultDataBufferFactory(), 1024);
}
}