diff --git a/build.gradle b/build.gradle index 6579189eef..14c58a4df1 100644 --- a/build.gradle +++ b/build.gradle @@ -670,6 +670,7 @@ project("spring-websocket") { exclude group: "javax.servlet", module: "javax.servlet" } optional("org.eclipse.jetty.websocket:websocket-client:${jettyVersion}") + optional("org.eclipse.jetty:jetty-client:${jettyVersion}") optional("io.undertow:undertow-core:1.0.15.Final") optional("io.undertow:undertow-servlet:1.0.15.Final") { exclude group: "org.jboss.spec.javax.servlet", module: "jboss-servlet-api_3.1_spec" diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java b/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java index 364a9597b7..9733e7f2a5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java @@ -217,7 +217,7 @@ public final class CloseStatus { @Override public String toString() { - return "CloseStatus [code=" + this.code + ", reason=" + this.reason + "]"; + return "CloseStatus[code=" + this.code + ", reason=" + this.reason + "]"; } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java index ea08db729a..dbec02cd99 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java @@ -73,10 +73,7 @@ public abstract class AbstractWebSocketClient implements WebSocketClient { WebSocketHttpHeaders headers, URI uri) { Assert.notNull(webSocketHandler, "webSocketHandler must not be null"); - Assert.notNull(uri, "uri must not be null"); - - String scheme = uri.getScheme(); - Assert.isTrue(((scheme != null) && ("ws".equals(scheme) || "wss".equals(scheme))), "Invalid scheme: " + scheme); + assertUri(uri); if (logger.isDebugEnabled()) { logger.debug("Connecting to " + uri); @@ -101,6 +98,12 @@ public abstract class AbstractWebSocketClient implements WebSocketClient { Collections.emptyMap()); } + protected void assertUri(URI uri) { + Assert.notNull(uri, "uri must not be null"); + String scheme = uri.getScheme(); + Assert.isTrue(scheme != null && ("ws".equals(scheme) || "wss".equals(scheme)), "Invalid scheme: " + scheme); + } + /** * Perform the actual handshake to establish a connection to the server. * diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index c82ebf8928..417955a47d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -194,7 +194,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } } catch (Throwable ex) { - logger.error("Failed to parse WebSocket message to STOMP frame(s)", ex); + logger.error("Failed to parse WebSocket message to STOMP." + + "Sending STOMP ERROR to client, sessionId=" + session.getId(), ex); sendErrorMessage(session, ex); return; } @@ -232,7 +233,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } } catch (Throwable ex) { - logger.error("Terminating STOMP session due to failure to send message", ex); + logger.error("Parsed STOMP message but could not send it to to message channel. " + + "Sending STOMP ERROR to client, sessionId=" + session.getId(), ex); sendErrorMessage(session, ex); } } @@ -248,7 +250,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } protected void sendErrorMessage(WebSocketSession session, Throwable error) { - StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.ERROR); headerAccessor.setMessage(error.getMessage()); byte[] bytes = this.stompEncoder.encode(headerAccessor.getMessageHeaders(), EMPTY_PAYLOAD); @@ -331,7 +332,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE throw ex; } catch (Throwable ex) { - sendErrorMessage(session, ex); + logger.error("Failed to send WebSocket message to client, sessionId=" + session.getId(), ex); + command = StompCommand.ERROR; } finally { if (StompCommand.ERROR.equals(command)) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractClientSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractClientSockJsSession.java new file mode 100644 index 0000000000..16217fa03c --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractClientSockJsSession.java @@ -0,0 +1,338 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.util.Assert; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketMessage; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; +import org.springframework.web.socket.sockjs.frame.SockJsFrameType; +import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; + +import java.io.IOException; +import java.net.URI; +import java.security.Principal; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Base class for SockJS client implementations of {@link WebSocketSession}. + * Provides processing of incoming SockJS message frames and delegates lifecycle + * events and messages to the (application) {@link WebSocketHandler}. + * Sub-classes implement actual send as well as disconnect logic. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public abstract class AbstractClientSockJsSession implements WebSocketSession { + + protected final Log logger = LogFactory.getLog(getClass()); + + + private final TransportRequest request; + + private final WebSocketHandler webSocketHandler; + + private final SettableListenableFuture connectFuture; + + + private final Map attributes = new ConcurrentHashMap(); + + private volatile State state = State.NEW; + + private volatile CloseStatus closeStatus; + + + protected AbstractClientSockJsSession(TransportRequest request, WebSocketHandler handler, + SettableListenableFuture connectFuture) { + + Assert.notNull(request, "'request' is required"); + Assert.notNull(handler, "'handler' is required"); + Assert.notNull(connectFuture, "'connectFuture' is required"); + this.request = request; + this.webSocketHandler = handler; + this.connectFuture = connectFuture; + } + + + @Override + public String getId() { + return this.request.getSockJsUrlInfo().getSessionId(); + } + + @Override + public URI getUri() { + return this.request.getSockJsUrlInfo().getSockJsUrl(); + } + + @Override + public HttpHeaders getHandshakeHeaders() { + return this.request.getHandshakeHeaders(); + } + + @Override + public Map getAttributes() { + return this.attributes; + } + + @Override + public Principal getPrincipal() { + return this.request.getUser(); + } + + public SockJsMessageCodec getMessageCodec() { + return this.request.getMessageCodec(); + } + + public WebSocketHandler getWebSocketHandler() { + return this.webSocketHandler; + } + + /** + * Return a timeout cleanup task to invoke if the SockJS sessions is not + * fully established within the retransmission timeout period calculated in + * {@code SockJsRequest} based on the duration of the initial SockJS "Info" + * request. + */ + Runnable getTimeoutTask() { + return new Runnable() { + @Override + public void run() { + closeInternal(new CloseStatus(2007, "Transport timed out")); + } + }; + } + + @Override + public boolean isOpen() { + return State.OPEN.equals(this.state); + } + + public boolean isDisconnected() { + return (State.CLOSING.equals(this.state) || State.CLOSED.equals(this.state)); + } + + @Override + public final void sendMessage(WebSocketMessage message) throws IOException { + Assert.state(State.OPEN.equals(this.state), this + " is not open, current state=" + this.state); + Assert.isInstanceOf(TextMessage.class, message, this + " supports text messages only."); + String payload = ((TextMessage) message).getPayload(); + payload = getMessageCodec().encode(new String[] { payload }); + payload = payload.substring(1); // the client-side doesn't need message framing (letter "a") + message = new TextMessage(payload); + if (logger.isTraceEnabled()) { + logger.trace("Sending message " + message + " in " + this); + } + sendInternal((TextMessage) message); + } + + protected abstract void sendInternal(TextMessage textMessage) throws IOException; + + @Override + public final void close() throws IOException { + close(CloseStatus.NORMAL); + } + + @Override + public final void close(CloseStatus status) { + Assert.isTrue(status != null && isUserSetStatus(status), "Invalid close status: " + status); + if (logger.isInfoEnabled()) { + logger.info("Closing session with " + status + " in " + this); + } + closeInternal(status); + } + + private boolean isUserSetStatus(CloseStatus status) { + return (status.getCode() == 1000 || (status.getCode() >= 3000 && status.getCode() <= 4999)); + } + + protected void closeInternal(CloseStatus status) { + if (this.state == null) { + logger.warn("Ignoring close since connect() was never invoked"); + return; + } + if (State.CLOSING.equals(this.state) || State.CLOSED.equals(this.state)) { + logger.debug("Ignoring close (already closing or closed), current state=" + this.state); + return; + } + this.state = State.CLOSING; + this.closeStatus = status; + try { + disconnect(status); + } + catch (Throwable ex) { + if (logger.isErrorEnabled()) { + logger.error("Failed to close " + this, ex); + } + } + } + + protected abstract void disconnect(CloseStatus status) throws IOException; + + public void handleFrame(String payload) { + SockJsFrame frame = new SockJsFrame(payload); + if (SockJsFrameType.OPEN.equals(frame.getType())) { + handleOpenFrame(); + } + else if (SockJsFrameType.MESSAGE.equals(frame.getType())) { + handleMessageFrame(frame); + } + else if (SockJsFrameType.CLOSE.equals(frame.getType())) { + handleCloseFrame(frame); + } + else if (SockJsFrameType.HEARTBEAT.equals(frame.getType())) { + if (logger.isTraceEnabled()) { + logger.trace("Received heartbeat in " + this); + } + } + else { + // should never happen + throw new IllegalStateException("Unknown SockJS frame type " + frame + " in " + this); + } + } + + private void handleOpenFrame() { + if (logger.isInfoEnabled()) { + logger.info("Processing SockJS open frame in " + this); + } + if (State.NEW.equals(state)) { + this.state = State.OPEN; + try { + this.webSocketHandler.afterConnectionEstablished(this); + this.connectFuture.set(this); + } + catch (Throwable ex) { + if (logger.isErrorEnabled()) { + Class type = this.webSocketHandler.getClass(); + logger.error(type + ".afterConnectionEstablished threw exception in " + this, ex); + } + } + } + else { + if (logger.isDebugEnabled()) { + logger.debug("Open frame received in " + getId() + " but we're not" + + "connecting (current state=" + this.state + "). The server might " + + "have been restarted and lost track of the session."); + } + closeInternal(new CloseStatus(1006, "Server lost session")); + } + } + + private void handleMessageFrame(SockJsFrame frame) { + if (!isOpen()) { + if (logger.isWarnEnabled()) { + logger.warn("Ignoring received message due to state=" + this.state + " in " + this); + } + return; + } + String[] messages; + try { + messages = getMessageCodec().decode(frame.getFrameData()); + } + catch (IOException ex) { + if (logger.isErrorEnabled()) { + logger.error("Failed to decode data for SockJS \"message\" frame: " + frame + " in " + this, ex); + } + closeInternal(CloseStatus.BAD_DATA); + return; + } + if (logger.isTraceEnabled()) { + logger.trace("Processing SockJS message frame " + frame.getContent() + " in " + this); + } + for (String message : messages) { + try { + if (isOpen()) { + this.webSocketHandler.handleMessage(this, new TextMessage(message)); + } + } + catch (Throwable ex) { + Class type = this.webSocketHandler.getClass(); + logger.error(type + ".handleMessage threw an exception on " + frame + " in " + this, ex); + } + } + } + + private void handleCloseFrame(SockJsFrame frame) { + CloseStatus closeStatus = CloseStatus.NO_STATUS_CODE; + try { + String[] data = getMessageCodec().decode(frame.getFrameData()); + if (data.length == 2) { + closeStatus = new CloseStatus(Integer.valueOf(data[0]), data[1]); + } + if (logger.isInfoEnabled()) { + logger.info("Processing SockJS close frame with " + closeStatus + " in " + this); + } + } + catch (IOException ex) { + if (logger.isErrorEnabled()) { + logger.error("Failed to decode data for " + frame + " in " + this, ex); + } + } + closeInternal(closeStatus); + } + + public void handleTransportError(Throwable error) { + try { + if (logger.isErrorEnabled()) { + logger.error("Transport error in " + this, error); + } + this.webSocketHandler.handleTransportError(this, error); + } + catch (Exception ex) { + Class type = this.webSocketHandler.getClass(); + if (logger.isErrorEnabled()) { + logger.error(type + ".handleTransportError threw an exception", ex); + } + } + } + + public void afterTransportClosed(CloseStatus closeStatus) { + this.closeStatus = (this.closeStatus != null ? this.closeStatus : closeStatus); + Assert.state(this.closeStatus != null, "CloseStatus not available"); + + if (logger.isInfoEnabled()) { + logger.info("Transport closed with " + this.closeStatus + " in " + this); + } + + this.state = State.CLOSED; + try { + this.webSocketHandler.afterConnectionClosed(this, this.closeStatus); + } + catch (Exception ex) { + if (logger.isErrorEnabled()) { + Class type = this.webSocketHandler.getClass(); + logger.error(type + ".afterConnectionClosed threw an exception", ex); + } + } + } + + @Override + public String toString() { + return getClass().getSimpleName() + "[id='" + getId() + ", url=" + getUri() + "]"; + } + + + private enum State { NEW, OPEN, CLOSING, CLOSED } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java new file mode 100644 index 0000000000..9ef2218528 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; + +import java.net.URI; + +/** + * Abstract base class for XHR transport implementations to extend. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public abstract class AbstractXhrTransport implements XhrTransport { + + protected static final String PRELUDE; + + static { + byte[] bytes = new byte[2048]; + for (int i = 0; i < bytes.length; i++) { + bytes[i] = 'h'; + } + PRELUDE = new String(bytes, SockJsFrame.CHARSET); + } + + protected Log logger = LogFactory.getLog(getClass()); + + private boolean xhrStreamingDisabled; + + private HttpHeaders requestHeaders = new HttpHeaders(); + + private HttpHeaders xhrSendRequestHeaders = new HttpHeaders(); + + + /** + * Whether to attempt to connect with "xhr_streaming" first before trying + * with "xhr" next, see {@link XhrTransport#isXhrStreamingDisabled()}. + * + *

By default this property is set to {@code false} which means both + * "xhr_streaming" and "xhr" will be tried. + */ + public void setXhrStreamingDisabled(boolean disabled) { + this.xhrStreamingDisabled = disabled; + } + + public boolean isXhrStreamingDisabled() { + return this.xhrStreamingDisabled; + } + + /** + * Configure headers to be added to every executed HTTP request. + * @param requestHeaders the headers to add to requests + */ + public void setRequestHeaders(HttpHeaders requestHeaders) { + this.requestHeaders.clear(); + this.xhrSendRequestHeaders.clear(); + if (requestHeaders != null) { + this.requestHeaders.putAll(requestHeaders); + this.xhrSendRequestHeaders.putAll(requestHeaders); + this.xhrSendRequestHeaders.setContentType(MediaType.APPLICATION_JSON); + } + } + + public HttpHeaders getRequestHeaders() { + return this.requestHeaders; + } + + @Override + public String executeInfoRequest(URI infoUrl) { + if (logger.isDebugEnabled()) { + logger.debug("Executing SockJS Info request, url=" + infoUrl); + } + ResponseEntity response = executeInfoRequestInternal(infoUrl); + if (response.getStatusCode() != HttpStatus.OK) { + if (logger.isErrorEnabled()) { + logger.error("SockJS Info request (url=" + infoUrl + ") failed: " + response); + } + throw new HttpServerErrorException(response.getStatusCode()); + } + if (logger.isDebugEnabled()) { + logger.debug("SockJS Info request (url=" + infoUrl + ") response: " + response); + } + return response.getBody(); + } + + protected abstract ResponseEntity executeInfoRequestInternal(URI infoUrl); + + @Override + public void executeSendRequest(URI url, TextMessage message) { + if (logger.isDebugEnabled()) { + logger.debug("Starting XHR send, url=" + url); + } + ResponseEntity response = executeSendRequestInternal(url, this.xhrSendRequestHeaders, message); + if (response.getStatusCode() != HttpStatus.NO_CONTENT) { + if (logger.isErrorEnabled()) { + logger.error("XHR send request (url=" + url + ") failed: " + response); + } + throw new HttpServerErrorException(response.getStatusCode()); + } + if (logger.isDebugEnabled()) { + logger.debug("XHR send request (url=" + url + ") response: " + response); + } + } + + protected abstract ResponseEntity executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message); + + @Override + public ListenableFuture connect(TransportRequest request, WebSocketHandler handler) { + SettableListenableFuture connectFuture = new SettableListenableFuture(); + XhrClientSockJsSession session = new XhrClientSockJsSession(request, handler, this, connectFuture); + request.addTimeoutTask(session.getTimeoutTask()); + + URI receiveUrl = request.getTransportUrl(); + if (logger.isDebugEnabled()) { + logger.debug("Opening XHR session, receive url=" + receiveUrl); + } + + HttpHeaders handshakeHeaders = new HttpHeaders(); + handshakeHeaders.putAll(request.getHandshakeHeaders()); + handshakeHeaders.putAll(getRequestHeaders()); + + connectInternal(request, handler, receiveUrl, handshakeHeaders, session, connectFuture); + return connectFuture; + } + + protected abstract void connectInternal(TransportRequest request, WebSocketHandler handler, + URI receiveUrl, HttpHeaders handshakeHeaders, XhrClientSockJsSession session, + SettableListenableFuture connectFuture); + + + @Override + public String toString() { + return getClass().getSimpleName(); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java new file mode 100644 index 0000000000..fd1e630f77 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java @@ -0,0 +1,238 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.util.Assert; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.SockJsTransportFailureException; +import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.net.URI; +import java.security.Principal; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * A default implementation of + * {@link org.springframework.web.socket.sockjs.client.TransportRequest + * TransportRequest}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +class DefaultTransportRequest implements TransportRequest { + + private static Log logger = LogFactory.getLog(DefaultTransportRequest.class); + + + private final SockJsUrlInfo sockJsUrlInfo; + + private final HttpHeaders handshakeHeaders; + + private final Transport transport; + + private final TransportType serverTransportType; + + private SockJsMessageCodec codec; + + private Principal user; + + private long timeoutValue; + + private TaskScheduler timeoutScheduler; + + private final List timeoutTasks = new ArrayList(); + + private DefaultTransportRequest fallbackRequest; + + + public DefaultTransportRequest(SockJsUrlInfo sockJsUrlInfo, HttpHeaders handshakeHeaders, + Transport transport, TransportType serverTransportType, SockJsMessageCodec codec) { + + Assert.notNull(sockJsUrlInfo, "'sockJsUrlInfo' is required"); + Assert.notNull(transport, "'transport' is required"); + Assert.notNull(serverTransportType, "'transportType' is required"); + Assert.notNull(codec, "'codec' is required"); + this.sockJsUrlInfo = sockJsUrlInfo; + this.handshakeHeaders = (handshakeHeaders != null ? handshakeHeaders : new HttpHeaders()); + this.transport = transport; + this.serverTransportType = serverTransportType; + this.codec = codec; + } + + + @Override + public SockJsUrlInfo getSockJsUrlInfo() { + return this.sockJsUrlInfo; + } + + @Override + public HttpHeaders getHandshakeHeaders() { + return this.handshakeHeaders; + } + + @Override + public URI getTransportUrl() { + return this.sockJsUrlInfo.getTransportUrl(this.serverTransportType); + } + + public void setUser(Principal user) { + this.user = user; + } + + @Override + public Principal getUser() { + return this.user; + } + + @Override + public SockJsMessageCodec getMessageCodec() { + return this.codec; + } + + public void setTimeoutValue(long timeoutValue) { + this.timeoutValue = timeoutValue; + } + + public void setTimeoutScheduler(TaskScheduler scheduler) { + this.timeoutScheduler = scheduler; + } + + @Override + public void addTimeoutTask(Runnable runnable) { + this.timeoutTasks.add(runnable); + } + + public void setFallbackRequest(DefaultTransportRequest fallbackRequest) { + this.fallbackRequest = fallbackRequest; + } + + + public void connect(WebSocketHandler handler, SettableListenableFuture future) { + if (logger.isDebugEnabled()) { + logger.debug("Starting " + this); + } + ConnectCallback connectCallback = new ConnectCallback(handler, future); + scheduleConnectTimeoutTask(connectCallback); + this.transport.connect(this, handler).addCallback(connectCallback); + } + + + private void scheduleConnectTimeoutTask(ConnectCallback connectHandler) { + if (this.timeoutScheduler != null) { + if (logger.isDebugEnabled()) { + logger.debug("Scheduling connect to time out after " + this.timeoutValue + " milliseconds"); + } + Date timeoutDate = new Date(System.currentTimeMillis() + this.timeoutValue); + this.timeoutScheduler.schedule(connectHandler, timeoutDate); + } + else if (logger.isDebugEnabled()) { + logger.debug("Connect timeout task not scheduled. Is SockJsClient configured with a TaskScheduler?"); + } + } + + + @Override + public String toString() { + return "TransportRequest[url=" + getTransportUrl() + "]"; + } + + + /** + * Updates the given (global) future based success or failure to connect for + * the entire SockJS request regardless of which transport actually managed + * to connect. Also implements {@code Runnable} to handle a scheduled timeout + * callback. + */ + private class ConnectCallback implements ListenableFutureCallback, Runnable { + + private final WebSocketHandler handler; + + private final SettableListenableFuture future; + + private final AtomicBoolean handled = new AtomicBoolean(false); + + + public ConnectCallback(WebSocketHandler handler, SettableListenableFuture future) { + this.handler = handler; + this.future = future; + } + + + @Override + public void onSuccess(WebSocketSession session) { + if (this.handled.compareAndSet(false, true)) { + this.future.set(session); + } + else { + logger.error("Connect success/failure already handled for " + DefaultTransportRequest.this); + } + } + + @Override + public void onFailure(Throwable failure) { + handleFailure(failure, false); + } + + @Override + public void run() { + handleFailure(null, true); + } + + private void handleFailure(Throwable failure, boolean isTimeoutFailure) { + if (this.handled.compareAndSet(false, true)) { + if (isTimeoutFailure) { + String message = "Connect timed out for " + DefaultTransportRequest.this; + logger.error(message); + failure = new SockJsTransportFailureException(message, getSockJsUrlInfo().getSessionId(), null); + } + if (fallbackRequest != null) { + logger.error(DefaultTransportRequest.this + " failed. Falling back on next transport.", failure); + fallbackRequest.connect(this.handler, this.future); + } + else { + logger.error("No more fallback transports after " + DefaultTransportRequest.this, failure); + this.future.setException(failure); + } + if (isTimeoutFailure) { + try { + for (Runnable runnable : timeoutTasks) { + runnable.run(); + } + } + catch (Throwable ex) { + logger.error("Transport failed to run timeout tasks for " + DefaultTransportRequest.this, ex); + } + } + } + else { + logger.error("Connect success/failure events already took place for " + + DefaultTransportRequest.this + ". Ignoring this additional failure event.", failure); + } + } + } +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java new file mode 100644 index 0000000000..ae8ba7bf01 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java @@ -0,0 +1,24 @@ +package org.springframework.web.socket.sockjs.client; + +import java.net.URI; + +/** + * A simple contract for executing the SockJS "Info" request before the SockJS + * session starts. The request is used to check server capabilities such as + * whether it permits use of the WebSocket transport. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public interface InfoReceiver { + + /** + * Perform an HTTP request to the SockJS "Info" URL. + * and return the resulting JSON response content, or raise an exception. + * + * @param infoUrl the URL to obtain SockJS server information from + * @return the body of the response + */ + String executeInfoRequest(URI infoUrl); + +} \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java new file mode 100644 index 0000000000..a4f8e23870 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java @@ -0,0 +1,252 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.eclipse.jetty.client.HttpClient; +import org.eclipse.jetty.client.api.ContentResponse; +import org.eclipse.jetty.client.api.Request; +import org.eclipse.jetty.client.api.Response; +import org.eclipse.jetty.client.util.StringContentProvider; +import org.eclipse.jetty.http.HttpFields; +import org.eclipse.jetty.http.HttpMethod; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.SockJsException; +import org.springframework.web.socket.sockjs.SockJsTransportFailureException; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; + +import java.io.ByteArrayOutputStream; +import java.net.URI; +import java.nio.ByteBuffer; +import java.util.Enumeration; + + +/** + * An XHR transport based on Jetty's {@link org.eclipse.jetty.client.HttpClient}. + * + *

When used for testing purposes (e.g. load testing) the {@code HttpClient} + * properties must be set to allow a larger than usual number of connections and + * threads. For example: + * + *

+ * HttpClient httpClient = new HttpClient();
+ * httpClient.setMaxConnectionsPerDestination(1000);
+ * httpClient.setExecutor(new QueuedThreadPool(500));
+ * 
+ * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class JettyXhrTransport extends AbstractXhrTransport implements XhrTransport { + + private final HttpClient httpClient; + + + public JettyXhrTransport(HttpClient httpClient) { + Assert.notNull(httpClient, "'httpClient' is required"); + this.httpClient = httpClient; + } + + + public HttpClient getHttpClient() { + return this.httpClient; + } + + @Override + protected ResponseEntity executeInfoRequestInternal(URI infoUrl) { + return executeRequest(infoUrl, HttpMethod.GET, getRequestHeaders(), null); + } + + @Override + public ResponseEntity executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) { + return executeRequest(url, HttpMethod.POST, headers, message.getPayload()); + } + + protected ResponseEntity executeRequest(URI url, HttpMethod method, HttpHeaders headers, String body) { + Request httpRequest = this.httpClient.newRequest(url).method(method); + addHttpHeaders(httpRequest, headers); + if (body != null) { + httpRequest.content(new StringContentProvider(body)); + } + ContentResponse response; + try { + response = httpRequest.send(); + } + catch (Exception ex) { + throw new SockJsTransportFailureException("Failed to execute request to " + url, null, ex); + } + HttpStatus status = HttpStatus.valueOf(response.getStatus()); + HttpHeaders responseHeaders = toHttpHeaders(response.getHeaders()); + return (response.getContent() != null ? + new ResponseEntity(response.getContentAsString(), responseHeaders, status) : + new ResponseEntity(responseHeaders, status)); + } + + private static void addHttpHeaders(Request request, HttpHeaders headers) { + for (String name : headers.keySet()) { + for (String value : headers.get(name)) { + request.header(name, value); + } + } + } + + private static HttpHeaders toHttpHeaders(HttpFields httpFields) { + HttpHeaders responseHeaders = new HttpHeaders(); + Enumeration names = httpFields.getFieldNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + Enumeration values = httpFields.getValues(name); + while (values.hasMoreElements()) { + String value = values.nextElement(); + responseHeaders.add(name, value); + } + } + return responseHeaders; + } + + @Override + protected void connectInternal(TransportRequest request, WebSocketHandler handler, + URI url, HttpHeaders handshakeHeaders, XhrClientSockJsSession session, + SettableListenableFuture connectFuture) { + + SockJsResponseListener listener = new SockJsResponseListener(url, getRequestHeaders(), session, connectFuture); + executeReceiveRequest(url, handshakeHeaders, listener); + } + + private void executeReceiveRequest(URI url, HttpHeaders headers, SockJsResponseListener listener) { + if (logger.isDebugEnabled()) { + logger.debug("Starting XHR receive request, url=" + url); + } + Request httpRequest = this.httpClient.newRequest(url).method(HttpMethod.POST); + addHttpHeaders(httpRequest, headers); + httpRequest.send(listener); + } + + + /** + * Splits the body of an HTTP response into SockJS frames and delegates those + * to an {@link XhrClientSockJsSession}. + */ + private class SockJsResponseListener extends Response.Listener.Adapter { + + private final URI transportUrl; + + private final HttpHeaders receiveHeaders; + + private final XhrClientSockJsSession sockJsSession; + + private final SettableListenableFuture connectFuture; + + private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + + + public SockJsResponseListener(URI url, HttpHeaders headers, XhrClientSockJsSession sockJsSession, + SettableListenableFuture connectFuture) { + + this.transportUrl = url; + this.receiveHeaders = headers; + this.connectFuture = connectFuture; + this.sockJsSession = sockJsSession; + } + + + @Override + public void onBegin(Response response) { + if (response.getStatus() != 200) { + HttpStatus status = HttpStatus.valueOf(response.getStatus()); + response.abort(new HttpServerErrorException(status, "Unexpected XHR receive status")); + } + } + + @Override + public void onHeaders(Response response) { + if (logger.isDebugEnabled()) { + // Convert to HttpHeaders to avoid "\n" + logger.debug("XHR receive headers: " + toHttpHeaders(response.getHeaders())); + } + } + + @Override + public void onContent(Response response, ByteBuffer buffer) { + while (true) { + if (this.sockJsSession.isDisconnected()) { + if (logger.isDebugEnabled()) { + logger.debug("SockJS sockJsSession closed. Closing ClientHttpResponse."); + } + response.abort(new SockJsException("Session closed.", this.sockJsSession.getId(), null)); + return; + } + if (buffer.remaining() == 0) { + break; + } + int b = buffer.get(); + if (b == '\n') { + handleFrame(); + } + else { + this.outputStream.write(b); + } + } + } + + private void handleFrame() { + byte[] bytes = this.outputStream.toByteArray(); + this.outputStream.reset(); + String content = new String(bytes, SockJsFrame.CHARSET); + if (logger.isTraceEnabled()) { + logger.trace("XHR content received: " + content); + } + if (!PRELUDE.equals(content)) { + this.sockJsSession.handleFrame(new String(bytes, SockJsFrame.CHARSET)); + } + } + + @Override + public void onSuccess(Response response) { + if (this.outputStream.size() > 0) { + handleFrame(); + } + if (logger.isDebugEnabled()) { + logger.debug("XHR receive request completed."); + } + executeReceiveRequest(this.transportUrl, this.receiveHeaders, this); + } + + @Override + public void onFailure(Response response, Throwable failure) { + if (connectFuture.setException(failure)) { + return; + } + if (this.sockJsSession.isDisconnected()) { + this.sockJsSession.afterTransportClosed(null); + } + else { + this.sockJsSession.handleTransportError(failure); + this.sockJsSession.afterTransportClosed(new CloseStatus(1006, failure.getMessage())); + } + } + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java new file mode 100644 index 0000000000..1f25be4009 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java @@ -0,0 +1,265 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.core.task.TaskExecutor; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.util.Assert; +import org.springframework.util.StreamUtils; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.client.RequestCallback; +import org.springframework.web.client.ResponseExtractor; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; + +/** + * An {@code XhrTransport} implementation that uses a + * {@link org.springframework.web.client.RestTemplate RestTemplate}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class RestTemplateXhrTransport extends AbstractXhrTransport implements XhrTransport { + + private final RestOperations restTemplate; + + private TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(); + + + public RestTemplateXhrTransport() { + this(new RestTemplate()); + } + + public RestTemplateXhrTransport(RestOperations restTemplate) { + Assert.notNull(restTemplate, "'restTemplate' is required"); + this.restTemplate = restTemplate; + } + + + /** + * Return the configured {@code RestTemplate}. + */ + public RestOperations getRestTemplate() { + return this.restTemplate; + } + + /** + * Configure the {@code TaskExecutor} to use to execute XHR receive requests. + * + *

By default {@link org.springframework.core.task.SimpleAsyncTaskExecutor + * SimpleAsyncTaskExecutor} is configured which creates a new thread every + * time the transports connects. + * + * @param taskExecutor the task executor, cannot be {@code null} + */ + public void setTaskExecutor(TaskExecutor taskExecutor) { + Assert.notNull(this.taskExecutor); + this.taskExecutor = taskExecutor; + } + + /** + * Return the configured {@code TaskExecutor}. + */ + public TaskExecutor getTaskExecutor() { + return this.taskExecutor; + } + + + @Override + public ResponseEntity executeInfoRequestInternal(URI infoUrl) { + RequestCallback requestCallback = new XhrRequestCallback(getRequestHeaders()); + return this.restTemplate.execute(infoUrl, HttpMethod.GET, requestCallback, textExtractor); + } + + @Override + public ResponseEntity executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) { + RequestCallback requestCallback = new XhrRequestCallback(headers, message.getPayload()); + return this.restTemplate.execute(url, HttpMethod.POST, requestCallback, textExtractor); + } + + @Override + protected void connectInternal(final TransportRequest request, final WebSocketHandler handler, + final URI receiveUrl, final HttpHeaders handshakeHeaders, final XhrClientSockJsSession session, + final SettableListenableFuture connectFuture) { + + getTaskExecutor().execute(new Runnable() { + @Override + public void run() { + XhrRequestCallback requestCallback = new XhrRequestCallback(handshakeHeaders); + XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(getRequestHeaders()); + XhrReceiveExtractor responseExtractor = new XhrReceiveExtractor(session); + while (true) { + if (session.isDisconnected()) { + session.afterTransportClosed(null); + break; + } + try { + if (logger.isDebugEnabled()) { + logger.debug("Starting XHR receive request, url=" + receiveUrl); + } + getRestTemplate().execute(receiveUrl, HttpMethod.POST, requestCallback, responseExtractor); + requestCallback = requestCallbackAfterHandshake; + } + catch (Throwable ex) { + if (!connectFuture.isDone()) { + connectFuture.setException(ex); + } + else { + session.handleTransportError(ex); + session.afterTransportClosed(new CloseStatus(1006, ex.getMessage())); + } + break; + } + } + } + }); + } + + + /** + * A RequestCallback to add the headers and (optionally) String content. + */ + private static class XhrRequestCallback implements RequestCallback { + + private final HttpHeaders headers; + + private final String body; + + + public XhrRequestCallback(HttpHeaders headers) { + this(headers, null); + } + + public XhrRequestCallback(HttpHeaders headers, String body) { + this.headers = headers; + this.body = body; + } + + + @Override + public void doWithRequest(ClientHttpRequest request) throws IOException { + if (this.headers != null) { + request.getHeaders().putAll(this.headers); + } + if (this.body != null) { + StreamUtils.copy(this.body, SockJsFrame.CHARSET, request.getBody()); + } + } + } + + /** + * A simple ResponseExtractor that reads the body into a String. + */ + private final static ResponseExtractor> textExtractor = + new ResponseExtractor>() { + + @Override + public ResponseEntity extractData(ClientHttpResponse response) throws IOException { + if (response.getBody() == null) { + return new ResponseEntity(response.getHeaders(), response.getStatusCode()); + } + else { + String body = StreamUtils.copyToString(response.getBody(), SockJsFrame.CHARSET); + return new ResponseEntity(body, response.getHeaders(), response.getStatusCode()); + } + } + }; + + /** + * Splits the body of an HTTP response into SockJS frames and delegates those + * to an {@link XhrClientSockJsSession}. + */ + private class XhrReceiveExtractor implements ResponseExtractor { + + private final XhrClientSockJsSession sockJsSession; + + + public XhrReceiveExtractor(XhrClientSockJsSession sockJsSession) { + this.sockJsSession = sockJsSession; + } + + + @Override + public Object extractData(ClientHttpResponse response) throws IOException { + if (!HttpStatus.OK.equals(response.getStatusCode())) { + throw new HttpServerErrorException(response.getStatusCode()); + } + if (logger.isDebugEnabled()) { + logger.debug("XHR receive headers: " + response.getHeaders()); + } + InputStream is = response.getBody(); + ByteArrayOutputStream os = new ByteArrayOutputStream(); + while (true) { + if (this.sockJsSession.isDisconnected()) { + if (logger.isDebugEnabled()) { + logger.debug("SockJS sockJsSession closed. Closing ClientHttpResponse."); + } + response.close(); + break; + } + int b = is.read(); + if (b == -1) { + if (os.size() > 0) { + handleFrame(os); + } + if (logger.isDebugEnabled()) { + logger.debug("XHR receive completed"); + } + break; + } + if (b == '\n') { + handleFrame(os); + } + else { + os.write(b); + } + } + return null; + } + + private void handleFrame(ByteArrayOutputStream os) { + byte[] bytes = os.toByteArray(); + os.reset(); + String content = new String(bytes, SockJsFrame.CHARSET); + if (logger.isTraceEnabled()) { + logger.trace("XHR receive content: " + content); + } + if (!PRELUDE.equals(content)) { + this.sockJsSession.handleFrame(new String(bytes, SockJsFrame.CHARSET)); + } + } + } + +} + diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java new file mode 100644 index 0000000000..5d8bef8cfe --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java @@ -0,0 +1,259 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.client.AbstractWebSocketClient; +import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; +import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.net.URI; +import java.security.Principal; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * A SockJS implementation of + * {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient} + * with HTTP-based fallback alternative simulating a WebSocket interaction. + * + * @author Rossen Stoyanchev + * @since 4.1 + * + * @see http://sockjs.org + * @see org.springframework.web.socket.sockjs.client.Transport + */ +public class SockJsClient extends AbstractWebSocketClient { + + private static final boolean jackson2Present = ClassUtils.isPresent( + "com.fasterxml.jackson.databind.ObjectMapper", SockJsClient.class.getClassLoader()); + + private static final Log logger = LogFactory.getLog(SockJsClient.class); + + + private final List transports; + + private InfoReceiver infoReceiver; + + private SockJsMessageCodec messageCodec; + + private TaskScheduler taskScheduler; + + private final Map infoCache = new ConcurrentHashMap(); + + + /** + * Create a {@code SockJsClient} with the given transports. + * @param transports the transports to use + */ + public SockJsClient(List transports) { + Assert.notEmpty(transports, "No transports provided"); + this.transports = new ArrayList(transports); + this.infoReceiver = initInfoReceiver(transports); + if (jackson2Present) { + this.messageCodec = new Jackson2SockJsMessageCodec(); + } + } + + private static InfoReceiver initInfoReceiver(List transports) { + for (Transport transport : transports) { + if (transport instanceof InfoReceiver) { + return ((InfoReceiver) transport); + } + } + return new RestTemplateXhrTransport(); + } + + + /** + * Configure the {@code InfoReceiver} to use to perform the SockJS "Info" + * request before the SockJS session starts. + * + *

By default this is initialized either by looking through the configured + * transports to find the first {@code XhrTransport} or by creating an + * instance of {@code RestTemplateXhrTransport}. + * + * @param infoReceiver the transport to use for the SockJS "Info" request + */ + public void setInfoReceiver(InfoReceiver infoReceiver) { + this.infoReceiver = infoReceiver; + } + + public InfoReceiver getInfoReceiver() { + return this.infoReceiver; + } + + /** + * Set the SockJsMessageCodec to use. + * + *

By default {@link org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec + * Jackson2SockJsMessageCodec} is used if Jackson is on the classpath. + * + * @param messageCodec the message messageCodec to use + */ + public void setMessageCodec(SockJsMessageCodec messageCodec) { + Assert.notNull(messageCodec, "'messageCodec' is required"); + this.messageCodec = messageCodec; + } + + public SockJsMessageCodec getMessageCodec() { + return this.messageCodec; + } + + /** + * Configure a {@code TaskScheduler} for scheduling a connect timeout task + * where the timeout value is calculated based on the duration of the initial + * SockJS info request. Having a connect timeout task is optional but can + * improve the speed with which the client falls back to alternative + * transport options. + * + *

By default no task scheduler is configured in which case it may take + * longer before a fallback transport can be used. + * + * @param taskScheduler the scheduler to use + */ + public void setTaskScheduler(TaskScheduler taskScheduler) { + this.taskScheduler = taskScheduler; + } + + public void clearServerInfoCache() { + this.infoCache.clear(); + } + + @Override + protected void assertUri(URI uri) { + Assert.notNull(uri, "uri must not be null"); + String scheme = uri.getScheme(); + Assert.isTrue(scheme != null && ("ws".equals(scheme) || "wss".equals(scheme) + || "http".equals(scheme) || "https".equals(scheme)), "Invalid scheme: " + scheme); + } + + @Override + protected ListenableFuture doHandshakeInternal(WebSocketHandler handler, + HttpHeaders handshakeHeaders, URI url, List protocols, + List extensions, Map attributes) { + + SettableListenableFuture connectFuture = new SettableListenableFuture(); + try { + SockJsUrlInfo sockJsUrlInfo = new SockJsUrlInfo(url); + ServerInfo serverInfo = getServerInfo(sockJsUrlInfo); + createFallbackChain(sockJsUrlInfo, handshakeHeaders, serverInfo).connect(handler, connectFuture); + } + catch (Throwable exception) { + if (logger.isErrorEnabled()) { + logger.error("Initial SockJS \"Info\" request to server failed, url=" + url, exception); + } + connectFuture.setException(exception); + } + return connectFuture; + } + + private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo) { + URI infoUrl = sockJsUrlInfo.getInfoUrl(); + ServerInfo info = this.infoCache.get(infoUrl); + if (info == null) { + long start = System.currentTimeMillis(); + String response = this.infoReceiver.executeInfoRequest(infoUrl); + long infoRequestTime = System.currentTimeMillis() - start; + info = new ServerInfo(response, infoRequestTime); + this.infoCache.put(infoUrl, info); + } + return info; + } + + private DefaultTransportRequest createFallbackChain(SockJsUrlInfo urlInfo, HttpHeaders headers, ServerInfo serverInfo) { + List requests = new ArrayList(this.transports.size()); + for (Transport transport : this.transports) { + if (transport instanceof XhrTransport) { + XhrTransport xhrTransport = (XhrTransport) transport; + if (!xhrTransport.isXhrStreamingDisabled()) { + addRequest(requests, urlInfo, headers, serverInfo, transport, TransportType.XHR_STREAMING); + } + addRequest(requests, urlInfo, headers, serverInfo, transport, TransportType.XHR); + } + else if (serverInfo.isWebSocketEnabled()) { + addRequest(requests, urlInfo, headers, serverInfo, transport, TransportType.WEBSOCKET); + } + } + Assert.notEmpty(requests, + "0 transports for request to " + urlInfo + " . Configured transports: " + + this.transports + ". SockJS server webSocketEnabled=" + serverInfo.isWebSocketEnabled()); + for (int i = 0; i < requests.size() - 1; i++) { + requests.get(i).setFallbackRequest(requests.get(i + 1)); + } + return requests.get(0); + } + + private void addRequest(List requests, SockJsUrlInfo info, HttpHeaders headers, + ServerInfo serverInfo, Transport transport, TransportType type) { + + DefaultTransportRequest request = new DefaultTransportRequest(info, headers, transport, type, getMessageCodec()); + request.setUser(getUser()); + if (this.taskScheduler != null) { + request.setTimeoutValue(serverInfo.getRetransmissionTimeout()); + request.setTimeoutScheduler(this.taskScheduler); + } + requests.add(request); + } + + /** + * Return the user to associate with the SockJS session and make available via + * {@link org.springframework.web.socket.WebSocketSession#getPrincipal() + * WebSocketSession#getPrincipal()}. + *

By default this method returns {@code null}. + * @return the user to associate with the session, possibly {@code null} + */ + protected Principal getUser() { + return null; + } + + + private static class ServerInfo { + + private final boolean webSocketEnabled; + + private final long responseTime; + + + private ServerInfo(String response, long responseTime) { + this.responseTime = responseTime; + this.webSocketEnabled = !response.matches(".*[\"']websocket[\"']\\s*:\\s*false.*"); + } + + public boolean isWebSocketEnabled() { + return this.webSocketEnabled; + } + + public long getRetransmissionTimeout() { + return (this.responseTime > 100 ? 4 * this.responseTime : this.responseTime + 300); + } + } + +} \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfo.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfo.java new file mode 100644 index 0000000000..6530f0a62b --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfo.java @@ -0,0 +1,115 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.springframework.util.AlternativeJdkIdGenerator; +import org.springframework.util.IdGenerator; +import org.springframework.web.socket.sockjs.transport.TransportType; +import org.springframework.web.util.UriComponentsBuilder; + +import java.net.URI; +import java.util.UUID; + +/** + * Given the base URL to a SockJS server endpoint, also provides methods to + * generate and obtain session and a server id used for construct a transport URL. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class SockJsUrlInfo { + + private static final IdGenerator idGenerator = new AlternativeJdkIdGenerator(); + + + private final URI sockJsUrl; + + private String serverId; + + private String sessionId; + + private UUID uuid; + + + public SockJsUrlInfo(URI sockJsUrl) { + this.sockJsUrl = sockJsUrl; + } + + + public URI getSockJsUrl() { + return this.sockJsUrl; + } + + public String getServerId() { + if (this.serverId == null) { + this.serverId = String.valueOf(Math.abs(getUuid().getMostSignificantBits()) % 1000); + } + return this.serverId; + } + + public String getSessionId() { + if (this.sessionId == null) { + this.sessionId = getUuid().toString().replace("-",""); + } + return this.sessionId; + } + + protected UUID getUuid() { + if (this.uuid == null) { + this.uuid = idGenerator.generateId(); + } + return this.uuid; + } + + public URI getInfoUrl() { + return UriComponentsBuilder.fromUri(this.sockJsUrl) + .scheme(getScheme(TransportType.XHR)) + .pathSegment("info") + .build(true).toUri(); + } + + public URI getTransportUrl(TransportType transportType) { + return UriComponentsBuilder.fromUri(this.sockJsUrl) + .scheme(getScheme(transportType)) + .pathSegment(getServerId()) + .pathSegment(getSessionId()) + .pathSegment(transportType.toString()) + .build(true).toUri(); + } + + private String getScheme(TransportType transportType) { + String scheme = this.sockJsUrl.getScheme(); + if (TransportType.WEBSOCKET.equals(transportType)) { + if (!scheme.startsWith("ws")) { + scheme = ("https".equals(scheme) ? "wss" : "ws"); + } + } + else { + if (!scheme.startsWith("http")) { + scheme = ("wss".equals(scheme) ? "https" : "http"); + } + } + return scheme; + } + + + @Override + public String toString() { + return "SockJsUrlInfo[url=" + this.sockJsUrl + "]"; + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/Transport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/Transport.java new file mode 100644 index 0000000000..41d554a5ef --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/Transport.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.socket.sockjs.client; + + +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; + +/** + * A client-side implementation for a SockJS transport. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public interface Transport { + + /** + * Connect the transport. + * + * @param request the transport request. + * @param webSocketHandler the application handler to delegate lifecycle events to. + * @return a future to indicate success or failure to connect. + */ + ListenableFuture connect(TransportRequest request, WebSocketHandler webSocketHandler); + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java new file mode 100644 index 0000000000..5e92b28296 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java @@ -0,0 +1,53 @@ +package org.springframework.web.socket.sockjs.client; + +import org.springframework.http.HttpHeaders; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.net.URI; +import java.security.Principal; + +/** + * Represents a request to connect to a SockJS service using a specific + * Transport. A single SockJS request however may require falling back + * and therefore multiple TransportRequest instances. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public interface TransportRequest { + + /** + * Return information about the SockJS URL including server and session id.. + */ + SockJsUrlInfo getSockJsUrlInfo(); + + /** + * Return the headers to send with the connect request. + */ + HttpHeaders getHandshakeHeaders(); + + /** + * Return the transport URL for the given transport. + * For an {@link XhrTransport} this is the URL for receiving messages. + */ + URI getTransportUrl(); + + /** + * Return the user associated with the request, if any. + */ + Principal getUser(); + + /** + * Return the message codec to use for encoding SockJS messages. + */ + SockJsMessageCodec getMessageCodec(); + + /** + * Register a timeout cleanup task to invoke if the SockJS session is not + * fully established within the calculated retransmission timeout period. + */ + void addTimeoutTask(Runnable runnable); + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketClientSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketClientSockJsSession.java new file mode 100644 index 0000000000..543a6a2c73 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketClientSockJsSession.java @@ -0,0 +1,136 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + + +import org.springframework.util.Assert; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.adapter.NativeWebSocketSession; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.List; + +/** + * An extension of {@link AbstractClientSockJsSession} wrapping and delegating + * to an actual WebSocket session. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class WebSocketClientSockJsSession extends AbstractClientSockJsSession implements NativeWebSocketSession { + + private WebSocketSession webSocketSession; + + + public WebSocketClientSockJsSession(TransportRequest request, WebSocketHandler handler, + SettableListenableFuture connectFuture) { + + super(request, handler, connectFuture); + } + + + @Override + public Object getNativeSession() { + return this.webSocketSession; + } + + @SuppressWarnings("unchecked") + @Override + public T getNativeSession(Class requiredType) { + if (requiredType != null) { + if (requiredType.isInstance(this.webSocketSession)) { + return (T) this.webSocketSession; + } + } + return null; + } + + @Override + public InetSocketAddress getLocalAddress() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getLocalAddress(); + } + + @Override + public InetSocketAddress getRemoteAddress() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getRemoteAddress(); + } + + @Override + public String getAcceptedProtocol() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getAcceptedProtocol(); + } + + @Override + public void setTextMessageSizeLimit(int messageSizeLimit) { + checkDelegateSessionInitialized(); + this.webSocketSession.setTextMessageSizeLimit(messageSizeLimit); + } + + @Override + public int getTextMessageSizeLimit() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getTextMessageSizeLimit(); + } + + @Override + public void setBinaryMessageSizeLimit(int messageSizeLimit) { + checkDelegateSessionInitialized(); + this.webSocketSession.setBinaryMessageSizeLimit(messageSizeLimit); + } + + @Override + public int getBinaryMessageSizeLimit() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getBinaryMessageSizeLimit(); + } + + @Override + public List getExtensions() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getExtensions(); + } + + private void checkDelegateSessionInitialized() { + Assert.state(this.webSocketSession != null, "WebSocketSession not yet initialized"); + } + + public void initializeDelegateSession(WebSocketSession session) { + this.webSocketSession = session; + } + + @Override + protected void sendInternal(TextMessage textMessage) throws IOException { + this.webSocketSession.sendMessage(textMessage); + } + + @Override + protected void disconnect(CloseStatus status) throws IOException { + if (this.webSocketSession != null) { + this.webSocketSession.close(status); + } + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketTransport.java new file mode 100644 index 0000000000..ef36a96840 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketTransport.java @@ -0,0 +1,129 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.util.Assert; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketHttpHeaders; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.client.WebSocketClient; +import org.springframework.web.socket.handler.TextWebSocketHandler; + +import java.net.URI; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * A SockJS {@link Transport} that uses a + * {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class WebSocketTransport implements Transport { + + private static Log logger = LogFactory.getLog(WebSocketTransport.class); + + private final WebSocketClient webSocketClient; + + + public WebSocketTransport(WebSocketClient webSocketClient) { + Assert.notNull(webSocketClient, "'webSocketClient' is required"); + this.webSocketClient = webSocketClient; + } + + + /** + * Return the configured {@code WebSocketClient}. + */ + public WebSocketClient getWebSocketClient() { + return this.webSocketClient; + } + + @Override + public ListenableFuture connect(TransportRequest request, WebSocketHandler handler) { + final SettableListenableFuture future = new SettableListenableFuture(); + WebSocketClientSockJsSession session = new WebSocketClientSockJsSession(request, handler, future); + handler = new ClientSockJsWebSocketHandler(session); + request.addTimeoutTask(session.getTimeoutTask()); + + URI url = request.getTransportUrl(); + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(request.getHandshakeHeaders()); + if (logger.isDebugEnabled()) { + logger.debug("Opening WebSocket connection, url=" + url); + } + this.webSocketClient.doHandshake(handler, headers, url).addCallback( + new ListenableFutureCallback() { + @Override + public void onSuccess(WebSocketSession webSocketSession) { + // WebSocket session ready, SockJS Session not yet + } + @Override + public void onFailure(Throwable t) { + future.setException(t); + } + }); + return future; + } + + @Override + public String toString() { + return "WebSocketTransport[client=" + this.webSocketClient + "]"; + } + + + private static class ClientSockJsWebSocketHandler extends TextWebSocketHandler { + + private final WebSocketClientSockJsSession sockJsSession; + + private final AtomicInteger connectCount = new AtomicInteger(0); + + + private ClientSockJsWebSocketHandler(WebSocketClientSockJsSession session) { + Assert.notNull(session); + this.sockJsSession = session; + } + + @Override + public void afterConnectionEstablished(WebSocketSession webSocketSession) throws Exception { + Assert.isTrue(this.connectCount.compareAndSet(0, 1)); + this.sockJsSession.initializeDelegateSession(webSocketSession); + } + + @Override + public void handleTextMessage(WebSocketSession webSocketSession, TextMessage message) throws Exception { + this.sockJsSession.handleFrame(message.getPayload()); + } + + @Override + public void handleTransportError(WebSocketSession webSocketSession, Throwable ex) throws Exception { + this.sockJsSession.handleTransportError(ex); + } + + @Override + public void afterConnectionClosed(WebSocketSession webSocketSession, CloseStatus status) throws Exception { + this.sockJsSession.afterTransportClosed(status); + } + } + +} \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java new file mode 100644 index 0000000000..92de51cea7 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java @@ -0,0 +1,111 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.springframework.util.Assert; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.util.List; + + +/** + * An extension of {@link AbstractClientSockJsSession} for use with HTTP + * transports simulating a WebSocket session. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class XhrClientSockJsSession extends AbstractClientSockJsSession { + + private final URI sendUrl; + + private final XhrTransport transport; + + private int textMessageSizeLimit = -1; + + private int binaryMessageSizeLimit = -1; + + + public XhrClientSockJsSession(TransportRequest request, WebSocketHandler handler, + XhrTransport transport, SettableListenableFuture connectFuture) { + + super(request, handler, connectFuture); + Assert.notNull(transport, "'restTemplate' is required"); + this.sendUrl = request.getSockJsUrlInfo().getTransportUrl(TransportType.XHR_SEND); + this.transport = transport; + } + + + @Override + public InetSocketAddress getLocalAddress() { + return null; + } + + @Override + public InetSocketAddress getRemoteAddress() { + return new InetSocketAddress(getUri().getHost(), getUri().getPort()); + } + + @Override + public String getAcceptedProtocol() { + return null; + } + + @Override + public void setTextMessageSizeLimit(int messageSizeLimit) { + this.textMessageSizeLimit = messageSizeLimit; + } + + @Override + public int getTextMessageSizeLimit() { + return this.textMessageSizeLimit; + } + + @Override + public void setBinaryMessageSizeLimit(int messageSizeLimit) { + this.binaryMessageSizeLimit = -1; + } + + @Override + public int getBinaryMessageSizeLimit() { + return this.binaryMessageSizeLimit; + } + + @Override + public List getExtensions() { + return null; + } + + @Override + protected void sendInternal(TextMessage message) { + this.transport.executeSendRequest(this.sendUrl, message); + } + + @Override + protected void disconnect(CloseStatus status) { + // Nothing to do, XHR transports check if session is disconnected + } + +} \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java new file mode 100644 index 0000000000..726b202464 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java @@ -0,0 +1,40 @@ +package org.springframework.web.socket.sockjs.client; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.ResponseEntity; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; + +import java.net.URI; + +/** + * A SockJS {@link Transport} that uses HTTP requests to simulate a WebSocket + * interaction. The {@code connect} method of the base {@code Transport} interface + * is used to receive messages from the server while the + * {@link #executeSendRequest(java.net.URI, org.springframework.web.socket.TextMessage) + * executeSendRequest(URI, TextMessage)} method here is used to send messages. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public interface XhrTransport extends Transport, InfoReceiver { + + /** + * An {@code XhrTransport} supports both the "xhr_streaming" and "xhr" SockJS + * server transports. From a client perspective there is no implementation + * difference. + * + *

By default an {@code XhrTransport} will be used with "xhr_streaming" + * first and then with "xhr", if the streaming fails to connect. In some + * cases it may be useful to suppress streaming so that only "xhr" is used. + */ + boolean isXhrStreamingDisabled(); + + /** + * Execute a request to send the message to the server. + * @param transportUrl the URL for sending messages. + * @param message the message to send + */ + void executeSendRequest(URI transportUrl, TextMessage message); + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/package-info.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/package-info.java new file mode 100644 index 0000000000..6ec13fd08e --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * SockJS client implementation of + * {@link org.springframework.web.socket.client.WebSocketClient}. + */ +package org.springframework.web.socket.sockjs.client; + diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java index 7e07bf7ced..99888fcdd5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java @@ -145,7 +145,7 @@ public class SockJsFrame { if (!(other instanceof SockJsFrame)) { return false; } - return this.content.equals(((SockJsFrame) other).content); + return (this.type.equals(((SockJsFrame) other).type) && this.content.equals(((SockJsFrame) other).content)); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java index d9a06db592..dc8d020b74 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java @@ -238,7 +238,7 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem } else { response.setStatusCode(HttpStatus.NOT_FOUND); - logger.warn("Session not found"); + logger.warn("Session not found, sessionId=" + sessionId); return; } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java b/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java index 9bea3bcd76..f4b2fd3f91 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java @@ -78,6 +78,7 @@ public class JettyWebSocketTestServer implements WebSocketTestServer { @Override public void stop() throws Exception { if (this.jettyServer.isRunning()) { + this.jettyServer.setStopTimeout(0); this.jettyServer.stop(); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java new file mode 100644 index 0000000000..016b7ef924 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java @@ -0,0 +1,394 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.WebSocketTestServer; +import org.springframework.web.socket.config.annotation.EnableWebSocket; +import org.springframework.web.socket.config.annotation.WebSocketConfigurer; +import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; +import org.springframework.web.socket.handler.TextWebSocketHandler; +import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.RequestUpgradeStrategy; +import org.springframework.web.socket.server.support.DefaultHandshakeHandler; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.hamcrest.Matchers.*; + +/** + * Integration tests using the + * {@link org.springframework.web.socket.sockjs.client.SockJsClient}. + * against actual SockJS server endpoints. + * + * @author Rossen Stoyanchev + */ +public abstract class AbstractSockJsIntegrationTests { + + protected Log logger = LogFactory.getLog(getClass()); + + private WebSocketTestServer server; + + private AnnotationConfigWebApplicationContext wac; + + private ErrorFilter errorFilter; + + private String baseUrl; + + + @Before + public void setup() throws Exception { + this.errorFilter = new ErrorFilter(); + this.wac = new AnnotationConfigWebApplicationContext(); + this.wac.register(TestConfig.class, upgradeStrategyConfigClass()); + this.wac.refresh(); + this.server = createWebSocketTestServer(); + this.server.deployConfig(this.wac, this.errorFilter); + this.server.start(); + this.baseUrl = "http://localhost:" + this.server.getPort(); + } + + @After + public void teardown() throws Exception { + try { + this.server.undeployConfig(); + } + catch (Throwable t) { + logger.error("Failed to undeploy application config", t); + } + try { + this.server.stop(); + } + catch (Throwable t) { + logger.error("Failed to stop server", t); + } + } + + protected abstract WebSocketTestServer createWebSocketTestServer(); + + protected abstract Class upgradeStrategyConfigClass(); + + protected abstract Transport getWebSocketTransport(); + + protected abstract AbstractXhrTransport getXhrTransport(); + + protected SockJsClient createSockJsClient(Transport... transports) { + return new SockJsClient(Arrays.asList(transports)); + } + + @Test + public void echoWebSocket() throws Exception { + testEcho(100, getWebSocketTransport()); + } + + @Test + public void echoXhrStreaming() throws Exception { + testEcho(100, getXhrTransport()); + } + + @Test + public void echoXhr() throws Exception { + AbstractXhrTransport xhrTransport = getXhrTransport(); + xhrTransport.setXhrStreamingDisabled(true); + testEcho(100, xhrTransport); + } + + @Test + public void closeAfterOneMessageWebSocket() throws Exception { + testCloseAfterOneMessage(getWebSocketTransport()); + } + + @Test + public void closeAfterOneMessageXhrStreaming() throws Exception { + testCloseAfterOneMessage(getXhrTransport()); + } + + @Test + public void closeAfterOneMessageXhr() throws Exception { + AbstractXhrTransport xhrTransport = getXhrTransport(); + xhrTransport.setXhrStreamingDisabled(true); + testCloseAfterOneMessage(xhrTransport); + } + + @Test + public void infoRequestFailure() throws Exception { + TestClientHandler handler = new TestClientHandler(); + this.errorFilter.responseStatusMap.put("/info", 500); + CountDownLatch latch = new CountDownLatch(1); + createSockJsClient(getWebSocketTransport()).doHandshake(handler, this.baseUrl + "/echo").addCallback( + new ListenableFutureCallback() { + @Override + public void onSuccess(WebSocketSession result) { + + } + @Override + public void onFailure(Throwable t) { + latch.countDown(); + } + } + ); + assertTrue(latch.await(5000, TimeUnit.MILLISECONDS)); + } + + @Test + public void fallbackAfterTransportFailure() throws Exception { + this.errorFilter.responseStatusMap.put("/websocket", 200); + this.errorFilter.responseStatusMap.put("/xhr_streaming", 500); + TestClientHandler handler = new TestClientHandler(); + Transport[] transports = { getWebSocketTransport(), getXhrTransport() }; + WebSocketSession session = createSockJsClient(transports).doHandshake(handler, this.baseUrl + "/echo").get(); + assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, session.getClass()); + TextMessage message = new TextMessage("message1"); + session.sendMessage(message); + handler.awaitMessage(message, 5000); + } + + @Test(timeout = 5000) + public void fallbackAfterConnectTimeout() throws Exception { + TestClientHandler clientHandler = new TestClientHandler(); + this.errorFilter.sleepDelayMap.put("/xhr_streaming", 10000L); + this.errorFilter.responseStatusMap.put("/xhr_streaming", 503); + SockJsClient sockJsClient = createSockJsClient(getXhrTransport()); + sockJsClient.setTaskScheduler(this.wac.getBean(ThreadPoolTaskScheduler.class)); + WebSocketSession clientSession = sockJsClient.doHandshake(clientHandler, this.baseUrl + "/echo").get(); + assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, clientSession.getClass()); + TextMessage message = new TextMessage("message1"); + clientSession.sendMessage(message); + clientHandler.awaitMessage(message, 5000); + clientSession.close(); + } + + + private void testEcho(int messageCount, Transport transport) throws Exception { + List messages = new ArrayList<>(); + for (int i = 0; i < messageCount; i++) { + messages.add(new TextMessage("m" + i)); + } + TestClientHandler handler = new TestClientHandler(); + WebSocketSession session = createSockJsClient(transport).doHandshake(handler, this.baseUrl + "/echo").get(); + for (TextMessage message : messages) { + session.sendMessage(message); + } + handler.awaitMessageCount(messageCount, 5000); + for (TextMessage message : messages) { + assertTrue("Message not received: " + message, handler.receivedMessages.remove(message)); + } + assertEquals("Remaining messages: " + handler.receivedMessages, 0, handler.receivedMessages.size()); + session.close(); + } + + private void testCloseAfterOneMessage(Transport transport) throws Exception { + TestClientHandler clientHandler = new TestClientHandler(); + createSockJsClient(transport).doHandshake(clientHandler, this.baseUrl + "/test").get(); + TestServerHandler serverHandler = this.wac.getBean(TestServerHandler.class); + + assertNotNull("afterConnectionEstablished should have been called", clientHandler.session); + serverHandler.awaitSession(5000); + + TextMessage message = new TextMessage("message1"); + serverHandler.session.sendMessage(message); + clientHandler.awaitMessage(message, 5000); + + CloseStatus expected = new CloseStatus(3500, "Oops"); + serverHandler.session.close(expected); + CloseStatus actual = clientHandler.awaitCloseStatus(5000); + if (transport instanceof XhrTransport) { + assertThat(actual, Matchers.anyOf(equalTo(expected), equalTo(new CloseStatus(3000, "Go away!")))); + } + else { + assertEquals(expected, actual); + } + } + + + @Configuration + @EnableWebSocket + static class TestConfig implements WebSocketConfigurer { + + @Autowired + private RequestUpgradeStrategy upgradeStrategy; + + @Override + public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { + HandshakeHandler handshakeHandler = new DefaultHandshakeHandler(this.upgradeStrategy); + registry.addHandler(new EchoHandler(), "/echo").setHandshakeHandler(handshakeHandler).withSockJS(); + registry.addHandler(testServerHandler(), "/test").setHandshakeHandler(handshakeHandler).withSockJS(); + } + + @Bean + public TestServerHandler testServerHandler() { + return new TestServerHandler(); + } + } + + private static interface Condition { + boolean match(); + } + + private static void awaitEvent(Condition condition, long timeToWait, String description) { + long timeToSleep = 200; + for (int i = 0 ; i < Math.floor(timeToWait / timeToSleep); i++) { + if (condition.match()) { + return; + } + try { + Thread.sleep(timeToSleep); + } + catch (InterruptedException e) { + throw new IllegalStateException("Interrupted while waiting for " + description, e); + } + } + throw new IllegalStateException("Timed out waiting for " + description); + } + + private static class TestClientHandler extends TextWebSocketHandler { + + private final BlockingQueue receivedMessages = new LinkedBlockingQueue<>(); + + private volatile WebSocketSession session; + + private volatile CloseStatus closeStatus; + + + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception { + this.session = session; + } + + @Override + protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { + this.receivedMessages.add(message); + } + + @Override + public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { + this.closeStatus = status; + } + + public void awaitMessageCount(final int count, long timeToWait) throws Exception { + awaitEvent(() -> receivedMessages.size() >= count, timeToWait, + count + " number of messages. Received so far: " + this.receivedMessages); + } + + public void awaitMessage(TextMessage expected, long timeToWait) throws InterruptedException { + TextMessage actual = this.receivedMessages.poll(timeToWait, TimeUnit.MILLISECONDS); + assertNotNull("Timed out waiting for [" + expected + "]", actual); + assertEquals(expected, actual); + } + + public CloseStatus awaitCloseStatus(long timeToWait) throws InterruptedException { + awaitEvent(() -> this.closeStatus != null, timeToWait, " CloseStatus"); + return this.closeStatus; + } + } + + private static class TestServerHandler extends TextWebSocketHandler { + + private WebSocketSession session; + + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception { + this.session = session; + } + + public WebSocketSession awaitSession(long timeToWait) throws InterruptedException { + awaitEvent(() -> this.session != null, timeToWait, " session"); + return this.session; + } + } + + private static class EchoHandler extends TextWebSocketHandler { + + @Override + protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { + session.sendMessage(message); + } + } + + private static class ErrorFilter implements Filter { + + private final Map responseStatusMap = new HashMap<>(); + + private final Map sleepDelayMap = new HashMap<>(); + + @Override + public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain) throws IOException, ServletException { + for (String suffix : this.sleepDelayMap.keySet()) { + if (((HttpServletRequest) req).getRequestURI().endsWith(suffix)) { + try { + Thread.sleep(this.sleepDelayMap.get(suffix)); + break; + } + catch (InterruptedException e) { + e.printStackTrace(); + } + } + } + for (String suffix : this.responseStatusMap.keySet()) { + if (((HttpServletRequest) req).getRequestURI().endsWith(suffix)) { + ((HttpServletResponse) resp).sendError(this.responseStatusMap.get(suffix)); + return; + } + } + chain.doFilter(req, resp); + } + + @Override + public void init(FilterConfig filterConfig) throws ServletException { + } + + @Override + public void destroy() { + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java new file mode 100644 index 0000000000..9d3f76b460 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java @@ -0,0 +1,280 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.util.List; + +import static org.junit.Assert.assertThat; +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.*; + +/** + * Unit tests for + * {@link org.springframework.web.socket.sockjs.client.AbstractClientSockJsSession}. + * + * @author Rossen Stoyanchev + */ +public class ClientSockJsSessionTests { + + private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec(); + + private TestClientSockJsSession session; + + private WebSocketHandler handler; + + private SettableListenableFuture connectFuture; + + @Rule + public final ExpectedException thrown = ExpectedException.none(); + + + @Before + public void setup() throws Exception { + SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com")); + Transport transport = mock(Transport.class); + TransportRequest request = new DefaultTransportRequest(urlInfo, null, transport, TransportType.XHR, CODEC); + this.handler = mock(WebSocketHandler.class); + this.connectFuture = new SettableListenableFuture<>(); + this.session = new TestClientSockJsSession(request, this.handler, this.connectFuture); + } + + + @Test + public void handleFrameOpen() throws Exception { + assertThat(this.session.isOpen(), is(false)); + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + assertThat(this.session.isOpen(), is(true)); + assertTrue(this.connectFuture.isDone()); + assertThat(this.connectFuture.get(), sameInstance(this.session)); + verify(this.handler).afterConnectionEstablished(this.session); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleFrameOpenWhenStatusNotNew() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + assertThat(this.session.isOpen(), is(true)); + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(1006, "Server lost session"))); + } + + @Test + public void handleFrameOpenWithWebSocketHandlerException() throws Exception { + doThrow(new IllegalStateException("Fake error")).when(this.handler).afterConnectionEstablished(this.session); + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + assertThat(this.session.isOpen(), is(true)); + } + + @Test + public void handleFrameMessage() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.handleFrame(SockJsFrame.messageFrame(CODEC, "foo", "bar").getContent()); + verify(this.handler).afterConnectionEstablished(this.session); + verify(this.handler).handleMessage(this.session, new TextMessage("foo")); + verify(this.handler).handleMessage(this.session, new TextMessage("bar")); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleFrameMessageWhenNotOpen() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.close(); + reset(this.handler); + this.session.handleFrame(SockJsFrame.messageFrame(CODEC, "foo", "bar").getContent()); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleFrameMessageWithBadData() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.handleFrame("a['bad data"); + assertThat(this.session.isOpen(), equalTo(false)); + assertThat(this.session.disconnectStatus, equalTo(CloseStatus.BAD_DATA)); + verify(this.handler).afterConnectionEstablished(this.session); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleFrameMessageWithWebSocketHandlerException() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + doThrow(new IllegalStateException("Fake error")).when(this.handler).handleMessage(this.session, new TextMessage("foo")); + doThrow(new IllegalStateException("Fake error")).when(this.handler).handleMessage(this.session, new TextMessage("bar")); + this.session.handleFrame(SockJsFrame.messageFrame(CODEC, "foo", "bar").getContent()); + assertThat(this.session.isOpen(), equalTo(true)); + verify(this.handler).afterConnectionEstablished(this.session); + verify(this.handler).handleMessage(this.session, new TextMessage("foo")); + verify(this.handler).handleMessage(this.session, new TextMessage("bar")); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleFrameClose() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.handleFrame(SockJsFrame.closeFrame(1007, "").getContent()); + assertThat(this.session.isOpen(), equalTo(false)); + assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(1007, ""))); + verify(this.handler).afterConnectionEstablished(this.session); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleTransportError() throws Exception { + final IllegalStateException ex = new IllegalStateException("Fake error"); + this.session.handleTransportError(ex); + verify(this.handler).handleTransportError(this.session, ex); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void afterTransportClosed() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.afterTransportClosed(CloseStatus.SERVER_ERROR); + assertThat(this.session.isOpen(), equalTo(false)); + verify(this.handler).afterConnectionEstablished(this.session); + verify(this.handler).afterConnectionClosed(this.session, CloseStatus.SERVER_ERROR); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void close() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.close(); + assertThat(this.session.isOpen(), equalTo(false)); + assertThat(this.session.disconnectStatus, equalTo(CloseStatus.NORMAL)); + verify(this.handler).afterConnectionEstablished(this.session); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void closeWithStatus() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.close(new CloseStatus(3000, "reason")); + assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(3000, "reason"))); + } + + @Test + public void closeWithNullStatus() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("Invalid close status"); + this.session.close(null); + } + + @Test + public void closeWithStatusOutOfRange() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("Invalid close status"); + this.session.close(new CloseStatus(2999, "reason")); + } + + @Test + public void timeoutTask() { + this.session.getTimeoutTask().run(); + assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(2007, "Transport timed out"))); + } + + @Test + public void send() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.sendMessage(new TextMessage("foo")); + assertThat(this.session.sentMessage, equalTo(new TextMessage("[\"foo\"]"))); + } + + + private static class TestClientSockJsSession extends AbstractClientSockJsSession { + + private TextMessage sentMessage; + + private CloseStatus disconnectStatus; + + + protected TestClientSockJsSession(TransportRequest request, WebSocketHandler handler, + SettableListenableFuture connectFuture) { + super(request, handler, connectFuture); + } + + @Override + protected void sendInternal(TextMessage textMessage) throws IOException { + this.sentMessage = textMessage; + } + + @Override + protected void disconnect(CloseStatus status) throws IOException { + this.disconnectStatus = status; + } + + @Override + public InetSocketAddress getLocalAddress() { + return null; + } + + @Override + public InetSocketAddress getRemoteAddress() { + return null; + } + + @Override + public String getAcceptedProtocol() { + return null; + } + + @Override + public void setTextMessageSizeLimit(int messageSizeLimit) { + + } + + @Override + public int getTextMessageSizeLimit() { + return 0; + } + + @Override + public void setBinaryMessageSizeLimit(int messageSizeLimit) { + + } + + @Override + public int getBinaryMessageSizeLimit() { + return 0; + } + + @Override + public List getExtensions() { + return null; + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java new file mode 100644 index 0000000000..6d1eeda631 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java @@ -0,0 +1,139 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpHeaders; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.io.IOException; +import java.net.URI; +import java.util.Date; +import java.util.concurrent.ExecutionException; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; + +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Unit tests for {@link DefaultTransportRequest}. + * + * @author Rossen Stoyanchev + */ +public class DefaultTransportRequestTests { + + private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec(); + + + private SettableListenableFuture connectFuture; + + private ListenableFutureCallback connectCallback; + + private TestTransport webSocketTransport; + + private TestTransport xhrTransport; + + + @Rule + public final ExpectedException thrown = ExpectedException.none(); + + + @SuppressWarnings("unchecked") + @Before + public void setup() throws Exception { + this.connectCallback = mock(ListenableFutureCallback.class); + this.connectFuture = new SettableListenableFuture<>(); + this.connectFuture.addCallback(this.connectCallback); + this.webSocketTransport = new TestTransport("WebSocketTestTransport"); + this.xhrTransport = new TestTransport("XhrTestTransport"); + } + + + @Test + @SuppressWarnings("unchecked") + public void connect() throws Exception { + DefaultTransportRequest request = createTransportRequest(this.webSocketTransport, TransportType.WEBSOCKET); + request.connect(null, this.connectFuture); + WebSocketSession session = mock(WebSocketSession.class); + this.webSocketTransport.getConnectCallback().onSuccess(session); + assertSame(session, this.connectFuture.get()); + } + + @Test + public void fallbackAfterTransportError() throws Exception { + DefaultTransportRequest request1 = createTransportRequest(this.webSocketTransport, TransportType.WEBSOCKET); + DefaultTransportRequest request2 = createTransportRequest(this.xhrTransport, TransportType.XHR_STREAMING); + request1.setFallbackRequest(request2); + request1.connect(null, this.connectFuture); + + // Transport error => fallback + this.webSocketTransport.getConnectCallback().onFailure(new IOException("Fake exception 1")); + assertFalse(this.connectFuture.isDone()); + assertTrue(this.xhrTransport.invoked()); + + // Transport error => no more fallback + this.xhrTransport.getConnectCallback().onFailure(new IOException("Fake exception 2")); + assertTrue(this.connectFuture.isDone()); + this.thrown.expect(ExecutionException.class); + this.thrown.expectMessage("Fake exception 2"); + this.connectFuture.get(); + } + + @Test + public void fallbackAfterTimeout() throws Exception { + TaskScheduler scheduler = mock(TaskScheduler.class); + Runnable sessionCleanupTask = mock(Runnable.class); + DefaultTransportRequest request1 = createTransportRequest(this.webSocketTransport, TransportType.WEBSOCKET); + DefaultTransportRequest request2 = createTransportRequest(this.xhrTransport, TransportType.XHR_STREAMING); + request1.setFallbackRequest(request2); + request1.setTimeoutScheduler(scheduler); + request1.addTimeoutTask(sessionCleanupTask); + request1.connect(null, this.connectFuture); + + assertTrue(this.webSocketTransport.invoked()); + assertFalse(this.xhrTransport.invoked()); + + // Get and invoke the scheduled timeout task + ArgumentCaptor taskCaptor = ArgumentCaptor.forClass(Runnable.class); + verify(scheduler).schedule(taskCaptor.capture(), any(Date.class)); + verifyNoMoreInteractions(scheduler); + taskCaptor.getValue().run(); + + assertTrue(this.xhrTransport.invoked()); + verify(sessionCleanupTask).run(); + } + + protected DefaultTransportRequest createTransportRequest(Transport transport, TransportType type) throws Exception { + SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com")); + return new DefaultTransportRequest(urlInfo, new HttpHeaders(), transport, type, CODEC); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/JettySockJsIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/JettySockJsIntegrationTests.java new file mode 100644 index 0000000000..89120794ec --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/JettySockJsIntegrationTests.java @@ -0,0 +1,101 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.eclipse.jetty.client.HttpClient; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.junit.After; +import org.junit.Before; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.socket.JettyWebSocketTestServer; +import org.springframework.web.socket.client.jetty.JettyWebSocketClient; +import org.springframework.web.socket.server.RequestUpgradeStrategy; +import org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy; + +import java.util.ArrayList; +import java.util.List; + +/** + * SockJS integration tests using Jetty for client and server. + * + * @author Rossen Stoyanchev + */ +public class JettySockJsIntegrationTests extends AbstractSockJsIntegrationTests { + + private WebSocketClient webSocketClient; + + private HttpClient httpClient; + + + @Before + public void setup() throws Exception { + super.setup(); + this.webSocketClient = new WebSocketClient(); + this.webSocketClient.start(); + this.httpClient = new HttpClient(); + this.httpClient.start(); + } + + @After + public void teardown() throws Exception { + super.teardown(); + try { + this.webSocketClient.stop(); + } + catch (Throwable ex) { + logger.error("Failed to stop Jetty WebSocketClient", ex); + } + try { + this.httpClient.stop(); + } + catch (Throwable ex) { + logger.error("Failed to stop Jetty HttpClient", ex); + } + } + + @Override + protected JettyWebSocketTestServer createWebSocketTestServer() { + return new JettyWebSocketTestServer(); + } + + @Override + protected Class upgradeStrategyConfigClass() { + return JettyTestConfig.class; + } + + @Override + protected Transport getWebSocketTransport() { + return new WebSocketTransport(new JettyWebSocketClient(this.webSocketClient)); + } + + @Override + protected AbstractXhrTransport getXhrTransport() { + return new JettyXhrTransport(this.httpClient); + } + + + @Configuration + static class JettyTestConfig { + + @Bean + public RequestUpgradeStrategy upgradeStrategy() { + return new JettyRequestUpgradeStrategy(); + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java new file mode 100644 index 0000000000..f5b9318d4a --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java @@ -0,0 +1,228 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.core.task.SyncTaskExecutor; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompEncoder; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.client.RequestCallback; +import org.springframework.web.client.ResponseExtractor; +import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingDeque; + +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Unit tests for {@link RestTemplateXhrTransport}. + * + * @author Rossen Stoyanchev + */ +public class RestTemplateXhrTransportTests { + + private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec(); + + private WebSocketHandler webSocketHandler; + + + @Before + public void setup() throws Exception { + this.webSocketHandler = mock(WebSocketHandler.class); + } + + + @Test + public void connectReceiveAndClose() throws Exception { + String body = "o\n" + "a[\"foo\"]\n" + "c[3000,\"Go away!\"]"; + ClientHttpResponse response = response(HttpStatus.OK, body); + connect(response); + + verify(this.webSocketHandler).afterConnectionEstablished(any()); + verify(this.webSocketHandler).handleMessage(any(), eq(new TextMessage("foo"))); + verify(this.webSocketHandler).afterConnectionClosed(any(), eq(new CloseStatus(3000, "Go away!"))); + verifyNoMoreInteractions(this.webSocketHandler); + } + + @Test + public void connectReceiveAndCloseWithPrelude() throws Exception { + StringBuilder sb = new StringBuilder(2048); + for (int i=0; i < 2048; i++) { + sb.append('h'); + } + String body = sb.toString() + "\n" + "o\n" + "a[\"foo\"]\n" + "c[3000,\"Go away!\"]"; + ClientHttpResponse response = response(HttpStatus.OK, body); + connect(response); + + verify(this.webSocketHandler).afterConnectionEstablished(any()); + verify(this.webSocketHandler).handleMessage(any(), eq(new TextMessage("foo"))); + verify(this.webSocketHandler).afterConnectionClosed(any(), eq(new CloseStatus(3000, "Go away!"))); + verifyNoMoreInteractions(this.webSocketHandler); + } + + @Test + public void connectReceiveAndCloseWithStompFrame() throws Exception { + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND); + accessor.setDestination("/destination"); + MessageHeaders headers = accessor.getMessageHeaders(); + Message message = MessageBuilder.createMessage("body".getBytes(Charset.forName("UTF-8")), headers); + byte[] bytes = new StompEncoder().encode(message); + TextMessage textMessage = new TextMessage(bytes); + SockJsFrame frame = SockJsFrame.messageFrame(new Jackson2SockJsMessageCodec(), textMessage.getPayload()); + + String body = "o\n" + frame.getContent() + "\n" + "c[3000,\"Go away!\"]"; + ClientHttpResponse response = response(HttpStatus.OK, body); + connect(response); + + verify(this.webSocketHandler).afterConnectionEstablished(any()); + verify(this.webSocketHandler).handleMessage(any(), eq(textMessage)); + verify(this.webSocketHandler).afterConnectionClosed(any(), eq(new CloseStatus(3000, "Go away!"))); + verifyNoMoreInteractions(this.webSocketHandler); + } + + @Test + public void connectFailure() throws Exception { + final HttpServerErrorException expected = new HttpServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR); + RestOperations restTemplate = mock(RestOperations.class); + when(restTemplate.execute(any(), eq(HttpMethod.POST), any(), any())).thenThrow(expected); + + final CountDownLatch latch = new CountDownLatch(1); + connect(restTemplate).addCallback( + new ListenableFutureCallback() { + @Override + public void onSuccess(WebSocketSession result) { + } + @Override + public void onFailure(Throwable actual) { + if (actual == expected) { + latch.countDown(); + } + } + } + ); + verifyNoMoreInteractions(this.webSocketHandler); + } + + @Test + public void errorResponseStatus() throws Exception { + connect(response(HttpStatus.OK, "o\n"), response(HttpStatus.INTERNAL_SERVER_ERROR, "Oops")); + + verify(this.webSocketHandler).afterConnectionEstablished(any()); + verify(this.webSocketHandler).handleTransportError(any(), any()); + verify(this.webSocketHandler).afterConnectionClosed(any(), any()); + verifyNoMoreInteractions(this.webSocketHandler); + } + + @Test + public void responseClosedAfterDisconnected() throws Exception { + String body = "o\n" + "c[3000,\"Go away!\"]\n" + "a[\"foo\"]\n"; + ClientHttpResponse response = response(HttpStatus.OK, body); + connect(response); + + verify(this.webSocketHandler).afterConnectionEstablished(any()); + verify(this.webSocketHandler).afterConnectionClosed(any(), any()); + verifyNoMoreInteractions(this.webSocketHandler); + verify(response).close(); + } + + private ListenableFuture connect(ClientHttpResponse... responses) throws Exception { + return connect(new TestRestTemplate(responses)); + } + + private ListenableFuture connect(RestOperations restTemplate, ClientHttpResponse... responses) + throws Exception { + + RestTemplateXhrTransport transport = new RestTemplateXhrTransport(restTemplate); + transport.setTaskExecutor(new SyncTaskExecutor()); + + SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com")); + HttpHeaders headers = new HttpHeaders(); + headers.add("h-foo", "h-bar"); + TransportRequest request = new DefaultTransportRequest(urlInfo, headers, transport, TransportType.XHR, CODEC); + + return transport.connect(request, this.webSocketHandler); + } + + private ClientHttpResponse response(HttpStatus status, String body) throws IOException { + ClientHttpResponse response = mock(ClientHttpResponse.class); + InputStream inputStream = getInputStream(body); + when(response.getStatusCode()).thenReturn(status); + when(response.getBody()).thenReturn(inputStream); + return response; + } + + private InputStream getInputStream(String content) { + byte[] bytes = content.getBytes(Charset.forName("UTF-8")); + return new ByteArrayInputStream(bytes); + } + + + + private static class TestRestTemplate extends RestTemplate { + + private Queue responses = new LinkedBlockingDeque<>(); + + + private TestRestTemplate(ClientHttpResponse... responses) { + this.responses.addAll(Arrays.asList(responses)); + } + + @Override + public T execute(URI url, HttpMethod method, RequestCallback callback, ResponseExtractor extractor) throws RestClientException { + try { + extractor.extractData(this.responses.remove()); + } + catch (Throwable t) { + throw new RestClientException("Failed to invoke extractor", t); + } + return null; + } + } + + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java new file mode 100644 index 0000000000..b304257803 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpStatus; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.client.TestTransport.XhrTestTransport; + +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link org.springframework.web.socket.sockjs.client.SockJsClient}. + * + * @author Rossen Stoyanchev + */ +public class SockJsClientTests { + + private static final String URL = "http://example.com"; + + private static final WebSocketHandler handler = mock(WebSocketHandler.class); + + + private SockJsClient sockJsClient; + + private InfoReceiver infoReceiver; + + private TestTransport webSocketTransport; + + private XhrTestTransport xhrTransport; + + private ListenableFutureCallback connectCallback; + + + @Before + @SuppressWarnings("unchecked") + public void setup() { + this.infoReceiver = mock(InfoReceiver.class); + this.webSocketTransport = new TestTransport("WebSocketTestTransport"); + this.xhrTransport = new XhrTestTransport("XhrTestTransport"); + + List transports = new ArrayList<>(); + transports.add(this.webSocketTransport); + transports.add(this.xhrTransport); + this.sockJsClient = new SockJsClient(transports); + this.sockJsClient.setInfoReceiver(this.infoReceiver); + + this.connectCallback = mock(ListenableFutureCallback.class); + } + + @Test + public void connectWebSocket() throws Exception { + setupInfoRequest(true); + this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback); + assertTrue(this.webSocketTransport.invoked()); + WebSocketSession session = mock(WebSocketSession.class); + this.webSocketTransport.getConnectCallback().onSuccess(session); + verify(this.connectCallback).onSuccess(session); + verifyNoMoreInteractions(this.connectCallback); + } + + @Test + public void connectWebSocketDisabled() throws URISyntaxException { + setupInfoRequest(false); + this.sockJsClient.doHandshake(handler, URL); + assertFalse(this.webSocketTransport.invoked()); + assertTrue(this.xhrTransport.invoked()); + assertTrue(this.xhrTransport.getRequest().getTransportUrl().toString().endsWith("xhr_streaming")); + } + + @Test + public void connectXhrStreamingDisabled() throws Exception { + setupInfoRequest(false); + this.xhrTransport.setStreamingDisabled(true); + this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback); + assertFalse(this.webSocketTransport.invoked()); + assertTrue(this.xhrTransport.invoked()); + assertTrue(this.xhrTransport.getRequest().getTransportUrl().toString().endsWith("xhr")); + } + + @Test + public void connectSockJsInfo() throws Exception { + setupInfoRequest(true); + this.sockJsClient.doHandshake(handler, URL); + verify(this.infoReceiver, times(1)).executeInfoRequest(any()); + } + + @Test + public void connectSockJsInfoCached() throws Exception { + setupInfoRequest(true); + this.sockJsClient.doHandshake(handler, URL); + this.sockJsClient.doHandshake(handler, URL); + this.sockJsClient.doHandshake(handler, URL); + verify(this.infoReceiver, times(1)).executeInfoRequest(any()); + } + + @Test + @SuppressWarnings("unchecked") + public void connectInfoRequestFailure() throws URISyntaxException { + HttpServerErrorException exception = new HttpServerErrorException(HttpStatus.SERVICE_UNAVAILABLE); + when(this.infoReceiver.executeInfoRequest(any())).thenThrow(exception); + this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback); + verify(this.connectCallback).onFailure(exception); + assertFalse(this.webSocketTransport.invoked()); + assertFalse(this.xhrTransport.invoked()); + } + + private void setupInfoRequest(boolean webSocketEnabled) { + when(this.infoReceiver.executeInfoRequest(any())).thenReturn("{\"entropy\":123," + + "\"origins\":[\"*:*\"],\"cookie_needed\":true,\"websocket\":" + webSocketEnabled + "}"); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfoTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfoTests.java new file mode 100644 index 0000000000..462e27b783 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfoTests.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Assert; +import org.junit.Test; +import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.net.URI; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Unit tests for {@code SockJsUrlInfo}. + * @author Rossen Stoyanchev + */ +public class SockJsUrlInfoTests { + + + @Test + public void serverId() throws Exception { + SockJsUrlInfo info = new SockJsUrlInfo(new URI("http://example.com")); + int serverId = Integer.valueOf(info.getServerId()); + assertTrue("Invalid serverId: " + serverId, serverId > 0 && serverId < 1000); + } + + @Test + public void sessionId() throws Exception { + SockJsUrlInfo info = new SockJsUrlInfo(new URI("http://example.com")); + assertEquals("Invalid sessionId: " + info.getSessionId(), 32, info.getSessionId().length()); + } + + @Test + public void infoUrl() throws Exception { + testInfoUrl("http", "http"); + testInfoUrl("http", "http"); + testInfoUrl("https", "https"); + testInfoUrl("https", "https"); + testInfoUrl("ws", "http"); + testInfoUrl("ws", "http"); + testInfoUrl("wss", "https"); + testInfoUrl("wss", "https"); + } + + private void testInfoUrl(String scheme, String expectedScheme) throws Exception { + SockJsUrlInfo info = new SockJsUrlInfo(new URI(scheme + "://example.com")); + Assert.assertThat(info.getInfoUrl(), is(equalTo(new URI(expectedScheme + "://example.com/info")))); + } + + @Test + public void transportUrl() throws Exception { + testTransportUrl("http", "http", TransportType.XHR_STREAMING); + testTransportUrl("http", "ws", TransportType.WEBSOCKET); + testTransportUrl("https", "https", TransportType.XHR_STREAMING); + testTransportUrl("https", "wss", TransportType.WEBSOCKET); + testTransportUrl("ws", "http", TransportType.XHR_STREAMING); + testTransportUrl("ws", "ws", TransportType.WEBSOCKET); + testTransportUrl("wss", "https", TransportType.XHR_STREAMING); + testTransportUrl("wss", "wss", TransportType.WEBSOCKET); + } + + private void testTransportUrl(String scheme, String expectedScheme, TransportType transportType) throws Exception { + SockJsUrlInfo info = new SockJsUrlInfo(new URI(scheme + "://example.com")); + String serverId = info.getServerId(); + String sessionId = info.getSessionId(); + String transport = transportType.toString().toLowerCase(); + URI expected = new URI(expectedScheme + "://example.com/" + serverId + "/" + sessionId + "/" + transport); + assertThat(info.getTransportUrl(transportType), equalTo(expected)); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java new file mode 100644 index 0000000000..f54083b2ce --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.mockito.ArgumentCaptor; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; + +import java.net.URI; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Test SockJS Transport. + * + * @author Rossen Stoyanchev + */ +class TestTransport implements Transport { + + private final String name; + + private TransportRequest request; + + private ListenableFuture future; + + + public TestTransport(String name) { + this.name = name; + } + + public TransportRequest getRequest() { + return this.request; + } + + public boolean invoked() { + return this.future != null; + } + + @SuppressWarnings("unchecked") + public ListenableFutureCallback getConnectCallback() { + ArgumentCaptor captor = ArgumentCaptor.forClass(ListenableFutureCallback.class); + verify(this.future).addCallback(captor.capture()); + return captor.getValue(); + } + + @SuppressWarnings("unchecked") + @Override + public ListenableFuture connect(TransportRequest request, WebSocketHandler handler) { + this.request = request; + this.future = mock(ListenableFuture.class); + return this.future; + } + + @Override + public String toString() { + return "TestTransport[" + name + "]"; + } + + + static class XhrTestTransport extends TestTransport implements XhrTransport { + + private boolean streamingDisabled; + + + XhrTestTransport(String name) { + super(name); + } + + public void setStreamingDisabled(boolean streamingDisabled) { + this.streamingDisabled = streamingDisabled; + } + + @Override + public boolean isXhrStreamingDisabled() { + return this.streamingDisabled; + } + + @Override + public void executeSendRequest(URI transportUrl, TextMessage message) { + } + + @Override + public String executeInfoRequest(URI infoUrl) { + return null; + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java new file mode 100644 index 0000000000..d13bc207a1 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java @@ -0,0 +1,155 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; + +import java.net.URI; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.notNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +/** + * Unit tests for + * {@link org.springframework.web.socket.sockjs.client.AbstractXhrTransport}. + * + * @author Rossen Stoyanchev + */ +public class XhrTransportTests { + + @Test + public void infoResponse() throws Exception { + TestXhrTransport transport = new TestXhrTransport(); + transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK); + assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"))); + } + + @Test(expected = HttpServerErrorException.class) + public void infoResponseError() throws Exception { + TestXhrTransport transport = new TestXhrTransport(); + transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.BAD_REQUEST); + assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"))); + } + + @Test + public void sendMessage() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + requestHeaders.set("foo", "bar"); + TestXhrTransport transport = new TestXhrTransport(); + transport.setRequestHeaders(requestHeaders); + transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.NO_CONTENT); + URI url = new URI("http://example.com"); + transport.executeSendRequest(url, new TextMessage("payload")); + assertEquals(2, transport.actualSendRequestHeaders.size()); + assertEquals("bar", transport.actualSendRequestHeaders.getFirst("foo")); + assertEquals(MediaType.APPLICATION_JSON, transport.actualSendRequestHeaders.getContentType()); + } + + @Test(expected = HttpServerErrorException.class) + public void sendMessageError() throws Exception { + TestXhrTransport transport = new TestXhrTransport(); + transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.BAD_REQUEST); + URI url = new URI("http://example.com"); + transport.executeSendRequest(url, new TextMessage("payload")); + } + + @Test + public void connect() throws Exception { + HttpHeaders handshakeHeaders = new HttpHeaders(); + handshakeHeaders.setOrigin("foo"); + + TransportRequest request = mock(TransportRequest.class); + when(request.getSockJsUrlInfo()).thenReturn(new SockJsUrlInfo(new URI("http://example.com"))); + when(request.getHandshakeHeaders()).thenReturn(handshakeHeaders); + + HttpHeaders requestHeaders = new HttpHeaders(); + requestHeaders.set("foo", "bar"); + + TestXhrTransport transport = new TestXhrTransport(); + transport.setRequestHeaders(requestHeaders); + + WebSocketHandler handler = mock(WebSocketHandler.class); + transport.connect(request, handler); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); + verify(request).getSockJsUrlInfo(); + verify(request).addTimeoutTask(captor.capture()); + verify(request).getTransportUrl(); + verify(request).getHandshakeHeaders(); + verifyNoMoreInteractions(request); + + assertEquals(2, transport.actualHandshakeHeaders.size()); + assertEquals("foo", transport.actualHandshakeHeaders.getOrigin()); + assertEquals("bar", transport.actualHandshakeHeaders.getFirst("foo")); + + assertFalse(transport.actualSession.isDisconnected()); + captor.getValue().run(); + assertTrue(transport.actualSession.isDisconnected()); + } + + + private static class TestXhrTransport extends AbstractXhrTransport { + + private ResponseEntity infoResponseToReturn; + + private ResponseEntity sendMessageResponseToReturn; + + private HttpHeaders actualSendRequestHeaders; + + private HttpHeaders actualHandshakeHeaders; + + private XhrClientSockJsSession actualSession; + + + @Override + protected ResponseEntity executeInfoRequestInternal(URI infoUrl) { + return this.infoResponseToReturn; + } + + @Override + protected ResponseEntity executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) { + this.actualSendRequestHeaders = headers; + return this.sendMessageResponseToReturn; + } + + @Override + protected void connectInternal(TransportRequest request, WebSocketHandler handler, URI receiveUrl, + HttpHeaders handshakeHeaders, XhrClientSockJsSession session, + SettableListenableFuture connectFuture) { + + this.actualHandshakeHeaders = handshakeHeaders; + this.actualSession = session; + } + } + +} diff --git a/spring-websocket/src/test/resources/log4j.properties b/spring-websocket/src/test/resources/log4j.properties index 8db186fb4e..0b8d9ec5f6 100644 --- a/spring-websocket/src/test/resources/log4j.properties +++ b/spring-websocket/src/test/resources/log4j.properties @@ -1,9 +1,9 @@ log4j.appender.console=org.apache.log4j.ConsoleAppender log4j.appender.console.layout=org.apache.log4j.PatternLayout -log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} [%c] - %m%n +log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} [%c][%t] - %m%n log4j.rootCategory=WARN, console log4j.logger.org.springframework.web=DEBUG -log4j.logger.org.springframework.web.socket=DEBUG +log4j.logger.org.springframework.web.socket=TRACE log4j.logger.org.springframework.messaging=DEBUG