diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/UndertowXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/UndertowXhrTransport.java
new file mode 100644
index 0000000000..8000c23ab9
--- /dev/null
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/UndertowXhrTransport.java
@@ -0,0 +1,466 @@
+/*
+ * Copyright 2002-2014 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.socket.sockjs.client;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.net.URI;
+import java.nio.ByteBuffer;
+import java.util.Iterator;
+import java.util.List;
+import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.CountDownLatch;
+
+
+import io.undertow.client.ClientCallback;
+import io.undertow.client.ClientConnection;
+import io.undertow.client.ClientExchange;
+import io.undertow.client.ClientRequest;
+import io.undertow.client.ClientResponse;
+import io.undertow.client.UndertowClient;
+import io.undertow.util.AttachmentKey;
+import io.undertow.util.HeaderMap;
+import io.undertow.util.HttpString;
+import io.undertow.util.Methods;
+import io.undertow.util.StringReadChannelListener;
+import org.xnio.ByteBufferSlicePool;
+import org.xnio.ChannelListener;
+import org.xnio.ChannelListeners;
+import org.xnio.IoUtils;
+import org.xnio.OptionMap;
+import org.xnio.Options;
+import org.xnio.Pool;
+import org.xnio.Pooled;
+import org.xnio.Xnio;
+import org.xnio.XnioWorker;
+import org.xnio.channels.StreamSinkChannel;
+import org.xnio.channels.StreamSourceChannel;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.ResponseEntity;
+import org.springframework.util.Assert;
+import org.springframework.util.concurrent.SettableListenableFuture;
+import org.springframework.web.client.HttpServerErrorException;
+import org.springframework.web.socket.CloseStatus;
+import org.springframework.web.socket.TextMessage;
+import org.springframework.web.socket.WebSocketHandler;
+import org.springframework.web.socket.WebSocketSession;
+import org.springframework.web.socket.sockjs.SockJsException;
+import org.springframework.web.socket.sockjs.SockJsTransportFailureException;
+import org.springframework.web.socket.sockjs.frame.SockJsFrame;
+
+/**
+ * An XHR transport based on Undertow's {@link io.undertow.client.UndertowClient}.
+ *
+ *
When used for testing purposes (e.g. load testing) or for specific use cases
+ * (like HTTPS configuration), a custom OptionMap should be provided:
+ *
+ *
+ * OptionMap optionMap = OptionMap.builder()
+ * .set(Options.WORKER_IO_THREADS, 8)
+ * .set(Options.TCP_NODELAY, true)
+ * .set(Options.KEEP_ALIVE, true)
+ * .set(Options.WORKER_NAME, "SockJSClient")
+ * .getMap();
+ *
+ * UndertowXhrTransport transport = new UndertowXhrTransport(optionMap);
+ *
+ *
+ * @author Brian Clozel
+ * @since 4.1.2
+ * @see org.xnio.Options
+ */
+public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTransport {
+
+ private static final AttachmentKey RESPONSE_BODY = AttachmentKey.create(String.class);
+
+ private final Pool bufferPool;
+
+ private final OptionMap optionMap;
+
+ private final XnioWorker worker;
+
+ private final UndertowClient httpClient;
+
+ public UndertowXhrTransport() throws IOException {
+ this(OptionMap.builder().parse(Options.WORKER_NAME, "SockJSClient").getMap());
+ }
+
+ public UndertowXhrTransport(OptionMap optionMap) throws IOException {
+ Assert.notNull(optionMap, "'optionMap' is required");
+ this.bufferPool = new ByteBufferSlicePool(1048, 1048);
+ this.optionMap = optionMap;
+ this.worker = Xnio.getInstance().createWorker(optionMap);
+ this.httpClient = UndertowClient.getInstance();
+ }
+
+ private static HttpHeaders toHttpHeaders(HeaderMap headerMap) {
+ HttpHeaders responseHeaders = new HttpHeaders();
+ Iterator names = headerMap.getHeaderNames().iterator();
+ while(names.hasNext()) {
+ HttpString name = names.next();
+ Iterator values = headerMap.get(name).iterator();
+ while(values.hasNext()) {
+ responseHeaders.add(name.toString(), values.next());
+ }
+ }
+ return responseHeaders;
+ }
+
+ private static void addHttpHeaders(ClientRequest request, HttpHeaders headers) {
+ HeaderMap headerMap = request.getRequestHeaders();
+ for (String name : headers.keySet()) {
+ for (String value : headers.get(name)) {
+ headerMap.add(HttpString.tryFromString(name), value);
+ }
+ }
+ }
+
+ /**
+ * Return Undertow's native HTTP client
+ */
+ public UndertowClient getHttpClient() {
+ return httpClient;
+ }
+
+ /**
+ * Return the {@link org.xnio.XnioWorker} backing the I/O operations for Undertow's HTTP client
+ * @see org.xnio.Xnio
+ */
+ public XnioWorker getWorker() {
+ return this.worker;
+ }
+
+ @Override
+ protected ResponseEntity executeInfoRequestInternal(URI infoUrl) {
+ return executeRequest(infoUrl, Methods.GET, getRequestHeaders(), null);
+ }
+
+ @Override
+ protected ResponseEntity executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) {
+ return executeRequest(url, Methods.POST, headers, message.getPayload());
+ }
+
+ protected ResponseEntity executeRequest(URI url, HttpString method, HttpHeaders headers, String body) {
+
+ final CountDownLatch latch = new CountDownLatch(1);
+ final List responses = new CopyOnWriteArrayList();
+ try {
+ final ClientConnection connection = this.httpClient.connect(url, this.worker,
+ this.bufferPool, this.optionMap).get();
+ try {
+ final ClientRequest request = new ClientRequest().setMethod(method).setPath(url.getPath());
+ request.getRequestHeaders().add(HttpString.tryFromString(HttpHeaders.HOST), url.getHost());
+ if(body !=null && !body.isEmpty()) {
+ request.getRequestHeaders().add(HttpString.tryFromString(HttpHeaders.CONTENT_LENGTH), body.length());
+ }
+ addHttpHeaders(request, headers);
+ connection.sendRequest(request, createRequestCallback(body, responses, latch));
+
+ latch.await();
+ final ClientResponse response = responses.iterator().next();
+ HttpStatus status = HttpStatus.valueOf(response.getResponseCode());
+ HttpHeaders responseHeaders = toHttpHeaders(response.getResponseHeaders());
+ String responseBody = response.getAttachment(RESPONSE_BODY);
+ return (responseBody != null ?
+ new ResponseEntity(responseBody, responseHeaders, status) :
+ new ResponseEntity(responseHeaders, status));
+ }
+ finally {
+ IoUtils.safeClose(connection);
+ }
+ }
+ catch (IOException ex) {
+ throw new SockJsTransportFailureException("Failed to execute request to " + url, null, ex);
+ }
+ catch(InterruptedException ex) {
+ throw new SockJsTransportFailureException("Failed to execute request to " + url, null, ex);
+ }
+
+ }
+
+ private ClientCallback createRequestCallback(final String body,
+ final List responses, final CountDownLatch latch) {
+
+ return new ClientCallback() {
+ @Override
+ public void completed(ClientExchange result) {
+ result.setResponseListener(new ClientCallback() {
+ @Override
+ public void completed(final ClientExchange result) {
+ responses.add(result.getResponse());
+
+ new StringReadChannelListener(result.getConnection().getBufferPool()) {
+ @Override
+ protected void stringDone(String string) {
+ result.getResponse().putAttachment(RESPONSE_BODY, string);
+ latch.countDown();
+ }
+
+ @Override
+ protected void error(IOException ex) {
+ onFailure(latch, ex);
+ }
+ }.setup(result.getResponseChannel());
+ }
+
+ @Override
+ public void failed(IOException ex) {
+ onFailure(latch, ex);
+ }
+ });
+ try {
+ if(body != null) {
+ result.getRequestChannel().write(ByteBuffer.wrap(body.getBytes()));
+ }
+ result.getRequestChannel().shutdownWrites();
+ if(!result.getRequestChannel().flush()) {
+ result.getRequestChannel().getWriteSetter()
+ .set(ChannelListeners.flushingChannelListener(null, null));
+ result.getRequestChannel().resumeWrites();
+ }
+ }
+ catch (IOException ex) {
+ onFailure(latch, ex);
+ }
+ }
+
+ @Override
+ public void failed(IOException ex) {
+ onFailure(latch, ex);
+ }
+
+ private void onFailure(final CountDownLatch latch, IOException ex) {
+ latch.countDown();
+ throw new SockJsTransportFailureException("Failed to execute request", null, ex);
+ }
+ };
+ }
+
+ @Override
+ protected void connectInternal(TransportRequest request, WebSocketHandler handler, URI receiveUrl,
+ HttpHeaders handshakeHeaders, XhrClientSockJsSession session, SettableListenableFuture connectFuture) {
+
+ executeReceiveRequest(receiveUrl, handshakeHeaders, session, connectFuture);
+ }
+
+ private void executeReceiveRequest(final URI url, final HttpHeaders headers, final XhrClientSockJsSession session,
+ final SettableListenableFuture connectFuture) {
+ if (logger.isTraceEnabled()) {
+ logger.trace("Starting XHR receive request, url=" + url);
+ }
+
+ this.httpClient.connect(
+ new ClientCallback() {
+ @Override
+ public void completed(ClientConnection result) {
+ final ClientRequest httpRequest = new ClientRequest().setMethod(Methods.POST).setPath(url.getPath());
+ httpRequest.getRequestHeaders().add(HttpString.tryFromString(HttpHeaders.HOST), url.getHost());
+ addHttpHeaders(httpRequest, headers);
+ result.sendRequest(httpRequest, createConnectCallback(url, getRequestHeaders(), session, connectFuture));
+ }
+
+ @Override
+ public void failed(IOException ex) {
+ throw new SockJsTransportFailureException("Failed to execute request to " + url, null, ex);
+ }
+ },
+ url, this.worker, this.bufferPool, this.optionMap);
+
+ }
+
+ private ClientCallback createConnectCallback(final URI url, final HttpHeaders headers,
+ final XhrClientSockJsSession sockJsSession, final SettableListenableFuture connectFuture) {
+
+ return new ClientCallback() {
+ @Override
+ public void completed(final ClientExchange result) {
+
+ result.setResponseListener(new ClientCallback() {
+ @Override
+ public void completed(final ClientExchange result) {
+
+ ClientResponse response = result.getResponse();
+ if(response.getResponseCode() != 200) {
+ HttpStatus status = HttpStatus.valueOf(response.getResponseCode());
+ IoUtils.safeClose(result.getConnection());
+ onFailure(new HttpServerErrorException(status, "Unexpected XHR receive status"));
+ }
+ else {
+ SockJsResponseListener listener = new SockJsResponseListener(result.getConnection(),
+ url, headers, sockJsSession, connectFuture);
+ listener.setup(result.getResponseChannel());
+ }
+ if (logger.isTraceEnabled()) {
+ logger.trace("XHR receive headers: " + toHttpHeaders(response.getResponseHeaders()));
+ }
+ try {
+ result.getRequestChannel().shutdownWrites();
+ if(!result.getRequestChannel().flush()) {
+ result.getRequestChannel().getWriteSetter()
+ .set(ChannelListeners.flushingChannelListener(null, null));
+ result.getRequestChannel().resumeWrites();
+ }
+ }
+ catch (IOException exc) {
+ IoUtils.safeClose(result.getConnection());
+ onFailure(exc);
+ }
+
+ }
+
+ @Override
+ public void failed(IOException exc) {
+ IoUtils.safeClose(result.getConnection());
+ onFailure(exc);
+ }
+ });
+ }
+
+ @Override
+ public void failed(IOException exc) {
+ onFailure(exc);
+ }
+
+ private void onFailure(Throwable failure) {
+ if (connectFuture.setException(failure)) {
+ return;
+ }
+ if (sockJsSession.isDisconnected()) {
+ sockJsSession.afterTransportClosed(null);
+ }
+ else {
+ sockJsSession.handleTransportError(failure);
+ sockJsSession.afterTransportClosed(new CloseStatus(1006, failure.getMessage()));
+ }
+ }
+ };
+
+ }
+
+ public class SockJsResponseListener implements ChannelListener {
+
+ private final ClientConnection connection;
+ private final URI url;
+ private final HttpHeaders headers;
+ private final XhrClientSockJsSession session;
+ private final SettableListenableFuture connectFuture;
+
+ private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+
+ public SockJsResponseListener(ClientConnection connection, URI url, HttpHeaders headers,
+ XhrClientSockJsSession sockJsSession, SettableListenableFuture connectFuture) {
+ this.connection = connection;
+ this.url = url;
+ this.headers = headers;
+ this.session = sockJsSession;
+ this.connectFuture = connectFuture;
+ }
+
+ public void setup(final StreamSourceChannel channel) {
+ channel.suspendReads();
+ channel.getReadSetter().set(this);
+ channel.resumeReads();
+ }
+
+ @Override
+ public void handleEvent(StreamSourceChannel channel) {
+ if (this.session.isDisconnected()) {
+ if (logger.isDebugEnabled()) {
+ logger.debug("SockJS sockJsSession closed, closing response.");
+ }
+ IoUtils.safeClose(this.connection);
+ throw new SockJsException("Session closed.", this.session.getId(), null);
+ }
+
+ Pooled pooled = this.connection.getBufferPool().allocate();
+
+ try {
+ int r;
+ do {
+ ByteBuffer buffer = pooled.getResource();
+ buffer.clear();
+ r = channel.read(buffer);
+ buffer.flip();
+ if (r == 0) {
+ return;
+ }
+ else if (r == -1) {
+ onSuccess();
+ }
+ else {
+ while(buffer.hasRemaining()) {
+ int b = buffer.get();
+ if (b == '\n') {
+ handleFrame();
+ }
+ else {
+ this.outputStream.write(b);
+ }
+ }
+ }
+
+ } while (r > 0);
+ }
+ catch (IOException exc) {
+ onFailure(exc);
+ }
+ finally {
+ pooled.free();
+ }
+ }
+
+ private void handleFrame() {
+ byte[] bytes = this.outputStream.toByteArray();
+ this.outputStream.reset();
+ String content = new String(bytes, SockJsFrame.CHARSET);
+ if (logger.isTraceEnabled()) {
+ logger.trace("XHR content received: " + content);
+ }
+ if (!PRELUDE.equals(content)) {
+ this.session.handleFrame(new String(bytes, SockJsFrame.CHARSET));
+ }
+ }
+
+ public void onSuccess() {
+ if (this.outputStream.size() > 0) {
+ handleFrame();
+ }
+ if (logger.isTraceEnabled()) {
+ logger.trace("XHR receive request completed.");
+ }
+ IoUtils.safeClose(this.connection);
+ executeReceiveRequest(this.url, this.headers, this.session, this.connectFuture);
+ }
+
+ public void onFailure(Throwable failure) {
+ IoUtils.safeClose(this.connection);
+ if (connectFuture.setException(failure)) {
+ return;
+ }
+ if (this.session.isDisconnected()) {
+ this.session.afterTransportClosed(null);
+ }
+ else {
+ this.session.handleTransportError(failure);
+ this.session.afterTransportClosed(new CloseStatus(1006, failure.getMessage()));
+ }
+ }
+ }
+
+}
diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/UndertowTestServer.java b/spring-websocket/src/test/java/org/springframework/web/socket/UndertowTestServer.java
index 111741237d..381163fabc 100644
--- a/spring-websocket/src/test/java/org/springframework/web/socket/UndertowTestServer.java
+++ b/spring-websocket/src/test/java/org/springframework/web/socket/UndertowTestServer.java
@@ -16,6 +16,8 @@
package org.springframework.web.socket;
+import java.io.IOException;
+
import javax.servlet.DispatcherType;
import javax.servlet.Filter;
import javax.servlet.Servlet;
@@ -34,6 +36,9 @@ import org.springframework.util.Assert;
import org.springframework.util.SocketUtils;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.servlet.DispatcherServlet;
+import org.xnio.ByteBufferSlicePool;
+import org.xnio.OptionMap;
+import org.xnio.Xnio;
import static io.undertow.servlet.Servlets.*;
@@ -65,15 +70,26 @@ public class UndertowTestServer implements WebSocketTestServer {
public void deployConfig(WebApplicationContext cxt, Filter... filters) {
Assert.state(this.port != -1, "setup() was never called");
DispatcherServletInstanceFactory servletFactory = new DispatcherServletInstanceFactory(cxt);
+ // manually building WebSocketDeploymentInfo in order to avoid class cast exceptions
+ // with tomcat's implementation when using undertow 1.1.0+
+ WebSocketDeploymentInfo info = new WebSocketDeploymentInfo();
+ try {
+ info.setWorker(Xnio.getInstance().createWorker(OptionMap.EMPTY));
+ info.setBuffers(new ByteBufferSlicePool(1024,1024));
+ }
+ catch (IOException ex) {
+ throw new IllegalStateException(ex);
+ }
+
DeploymentInfo servletBuilder = deployment()
.setClassLoader(UndertowTestServer.class.getClassLoader())
.setDeploymentName("undertow-websocket-test")
.setContextPath("/")
- .addServlet(servlet("DispatcherServlet", DispatcherServlet.class, servletFactory).addMapping("/"))
- .addServletContextAttribute(WebSocketDeploymentInfo.ATTRIBUTE_NAME, new WebSocketDeploymentInfo());
+ .addServlet(servlet("DispatcherServlet", DispatcherServlet.class, servletFactory).addMapping("/").setAsyncSupported(true))
+ .addServletContextAttribute(WebSocketDeploymentInfo.ATTRIBUTE_NAME, info);
for (final Filter filter : filters) {
String filterName = filter.getClass().getName();
- servletBuilder.addFilter(new FilterInfo(filterName, filter.getClass(), new FilterInstanceFactory(filter)));
+ servletBuilder.addFilter(new FilterInfo(filterName, filter.getClass(), new FilterInstanceFactory(filter)).setAsyncSupported(true));
for (DispatcherType type : DispatcherType.values()) {
servletBuilder.addFilterUrlMapping(filterName, "/*", type);
}
diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java
index 6ab8c78ed4..4cf0950e5d 100644
--- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java
+++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java
@@ -140,7 +140,6 @@ public abstract class AbstractSockJsIntegrationTests {
// Temporarily @Ignore failures caused by suspected Jetty bug
- @Ignore
@Test
public void echoWebSocket() throws Exception {
testEcho(100, createWebSocketTransport());
@@ -158,7 +157,6 @@ public abstract class AbstractSockJsIntegrationTests {
testEcho(100, xhrTransport);
}
- @Ignore
@Test
public void receiveOneMessageWebSocket() throws Exception {
testReceiveOneMessage(createWebSocketTransport());
diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/UndertowSockJsIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/UndertowSockJsIntegrationTests.java
new file mode 100644
index 0000000000..f9b34ca7d6
--- /dev/null
+++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/UndertowSockJsIntegrationTests.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2002-2014 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.socket.sockjs.client;
+
+import java.io.IOException;
+
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.socket.UndertowTestServer;
+import org.springframework.web.socket.WebSocketTestServer;
+import org.springframework.web.socket.client.standard.StandardWebSocketClient;
+import org.springframework.web.socket.server.RequestUpgradeStrategy;
+import org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy;
+
+/**
+ * @author Brian Clozel
+ */
+public class UndertowSockJsIntegrationTests extends AbstractSockJsIntegrationTests {
+
+ @Override
+ protected Class> upgradeStrategyConfigClass() {
+ return UndertowTestConfig.class;
+ }
+
+ @Override
+ protected WebSocketTestServer createWebSocketTestServer() {
+ return new UndertowTestServer();
+ }
+
+ @Override
+ protected Transport createWebSocketTransport() {
+ return new WebSocketTransport(new StandardWebSocketClient());
+ }
+
+ @Override
+ protected AbstractXhrTransport createXhrTransport() {
+ try {
+ return new UndertowXhrTransport();
+ }
+ catch (IOException ex) {
+ throw new IllegalStateException("Could not create UndertowXhrTransport");
+ }
+ }
+
+ @Configuration
+ static class UndertowTestConfig {
+ @Bean
+ public RequestUpgradeStrategy upgradeStrategy() {
+ return new UndertowRequestUpgradeStrategy();
+ }
+ }
+}