From 319f18dddf885ee10bd8a167a88e8d574e546e15 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 13 Aug 2013 17:23:33 -0400 Subject: [PATCH] Add HandshakeInterceptor A HandshakeInterceptor can be used to intercept WebSocket handshakes (or SockJS requests where a new session is created) in order to inspect the request and response before and after the handshake including the ability to pass attributes to the WebSocketHandler, which the hander can access through WebSocketSession.getHandshakeAttributes() An HttpSessionHandshakeInterceptor is available that can copy attributes from the HTTP session to make them available to the WebSocket session. Issue: SPR-10624 --- .../web/socket/WebSocketSession.java | 8 + .../adapter/AbstractWebSocketSesssion.java | 19 +++ .../adapter/DelegatingWebSocketSession.java | 4 +- .../socket/adapter/JettyWebSocketSession.java | 6 +- .../adapter/StandardWebSocketSession.java | 16 +- .../client/AbstractWebSocketClient.java | 14 +- .../endpoint/StandardWebSocketClient.java | 8 +- .../client/jetty/JettyWebSocketClient.java | 22 ++- .../server/DefaultHandshakeHandler.java | 5 +- .../web/socket/server/HandshakeHandler.java | 11 +- .../socket/server/HandshakeInterceptor.java | 66 ++++++++ .../socket/server/RequestUpgradeStrategy.java | 11 +- .../AbstractStandardUpgradeStrategy.java | 12 +- .../support/HandshakeInterceptorChain.java | 81 ++++++++++ .../HttpSessionHandshakeInterceptor.java | 99 ++++++++++++ .../support/JettyRequestUpgradeStrategy.java | 5 +- .../support/WebSocketHttpRequestHandler.java | 57 ++++++- .../sockjs/support/AbstractSockJsService.java | 8 +- .../handler/DefaultSockJsService.java | 148 +++++++++++++----- .../handler/EventSourceTransportHandler.java | 13 +- .../handler/HtmlFileTransportHandler.java | 13 +- .../handler/JsonpPollingTransportHandler.java | 7 +- .../handler/SockJsSessionFactory.java | 11 +- .../handler/WebSocketTransportHandler.java | 16 +- .../handler/XhrPollingTransportHandler.java | 7 +- .../handler/XhrStreamingTransportHandler.java | 13 +- .../session/AbstractHttpSockJsSession.java | 22 ++- .../session/AbstractSockJsSession.java | 47 +++--- .../session/PollingSockJsSession.java | 9 +- .../session/StreamingSockJsSession.java | 7 +- .../session/WebSocketServerSockJsSession.java | 15 +- .../JettyWebSocketHandlerAdapterTests.java | 2 +- .../StandardWebSocketHandlerAdapterTests.java | 2 +- .../jetty/JettyWebSocketClientTests.java | 2 +- .../server/DefaultHandshakeHandlerTests.java | 8 +- .../HandshakeInterceptorChainTests.java | 101 ++++++++++++ .../HttpSessionHandshakeInterceptorTests.java | 86 ++++++++++ .../handler/DefaultSockJsServiceTests.java | 6 +- .../HttpReceivingTransportHandlerTests.java | 6 +- .../HttpSendingTransportHandlerTests.java | 10 +- .../AbstractHttpSockJsSessionTests.java | 9 +- .../session/AbstractSockJsSessionTests.java | 6 +- .../session/TestHttpSockJsSession.java | 7 +- .../transport/session/TestSockJsSession.java | 19 ++- .../WebSocketServerSockJsSessionTests.java | 10 +- .../socket/support/TestWebSocketSession.java | 19 +++ 46 files changed, 903 insertions(+), 170 deletions(-) create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeInterceptor.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/server/support/HandshakeInterceptorChain.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/server/support/HandshakeInterceptorChainTests.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java index bf14730e9c..b935e21889 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.Map; import org.springframework.http.HttpHeaders; @@ -48,6 +49,13 @@ public interface WebSocketSession { */ HttpHeaders getHandshakeHeaders(); + /** + * Handshake request specific attributes. + * To add attributes to a server-side WebSocket session see + * {@link org.springframework.web.socket.server.HandshakeInterceptor}. + */ + Map getHandshakeAttributes(); + /** * Return a {@link java.security.Principal} instance containing the name of the * authenticated user. If the user has not been authenticated, the method returns diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssion.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssion.java index 2faffbfb13..fc3fe82884 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssion.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssion.java @@ -16,6 +16,7 @@ package org.springframework.web.socket.adapter; import java.io.IOException; +import java.util.Map; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -38,6 +39,24 @@ public abstract class AbstractWebSocketSesssion implements DelegatingWebSocke private T delegateSession; + private final Map handshakeAttributes; + + + /** + * Class constructor + * + * @param handshakeAttributes attributes from the HTTP handshake to make available + * through the WebSocket session + */ + public AbstractWebSocketSesssion(Map handshakeAttributes) { + this.handshakeAttributes = handshakeAttributes; + } + + + @Override + public Map getHandshakeAttributes() { + return this.handshakeAttributes; + } /** * @return the WebSocket session to delegate to diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/DelegatingWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/DelegatingWebSocketSession.java index 73260ca668..807eed1908 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/DelegatingWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/DelegatingWebSocketSession.java @@ -20,8 +20,8 @@ import org.springframework.web.socket.WebSocketSession; /** - * A contract for {@link WebSocketSession} implementations that delegate to another - * WebSocket session (e.g. a native session). + * A contract for a {@link WebSocketSession} that delegates to another WebSocket session + * (e.g. a native session). * * @param T the type of the delegate WebSocket session * diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java index 708ac66bc1..e67d94fc51 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.util.ObjectUtils; @@ -46,8 +47,11 @@ public class JettyWebSocketSession extends AbstractWebSocketSesssion handshakeAttributes) { + super(handshakeAttributes); this.principal = principal; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java index 8dd2c1fd38..f59ac55b27 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.Map; import javax.websocket.CloseReason; import javax.websocket.CloseReason.CloseCodes; @@ -39,7 +40,7 @@ import org.springframework.web.socket.WebSocketSession; */ public class StandardWebSocketSession extends AbstractWebSocketSesssion { - private final HttpHeaders headers; + private final HttpHeaders handshakeHeaders; private final InetSocketAddress localAddress; @@ -50,12 +51,17 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion handshakeAttributes, + InetSocketAddress localAddress, InetSocketAddress remoteAddress) { + super(handshakeAttributes); handshakeHeaders = (handshakeHeaders != null) ? handshakeHeaders : new HttpHeaders(); - this.headers = HttpHeaders.readOnlyHttpHeaders(handshakeHeaders); + this.handshakeHeaders = HttpHeaders.readOnlyHttpHeaders(handshakeHeaders); this.localAddress = localAddress; this.remoteAddress = remoteAddress; } @@ -74,7 +80,7 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssionemptyMap()); } /** - * + * Perform the actual handshake to establish a connection to the server. * * @param webSocketHandler the client-side handler for WebSocket messages * @param headers HTTP headers to use for the handshake, with unwanted (forbidden) * headers filtered out, never {@code null} * @param uri the target URI for the handshake, never {@code null} * @param subProtocols requested sub-protocols, or an empty list + * @param handshakeAttributes attributes to make available via + * {@link WebSocketSession#getHandshakeAttributes()}; currently always an empty map. + * * @return the established WebSocket session + * * @throws WebSocketConnectFailureException */ protected abstract WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler, - HttpHeaders headers, URI uri, List subProtocols) throws WebSocketConnectFailureException; + HttpHeaders headers, URI uri, List subProtocols, + Map handshakeAttributes) throws WebSocketConnectFailureException; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java index 931262d15f..a6ff22586f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java @@ -63,14 +63,16 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { @Override - protected WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler, - HttpHeaders headers, URI uri, List protocols) throws WebSocketConnectFailureException { + protected WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler, HttpHeaders headers, + URI uri, List protocols, Map handshakeAttributes) + throws WebSocketConnectFailureException { int port = getPort(uri); InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port); InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), port); - StandardWebSocketSession session = new StandardWebSocketSession(headers, localAddress, remoteAddress); + StandardWebSocketSession session = new StandardWebSocketSession(headers, + handshakeAttributes, localAddress, remoteAddress); ClientEndpointConfig.Builder configBuidler = ClientEndpointConfig.Builder.create(); configBuidler.configurator(new StandardWebSocketClientConfigurator(headers)); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java index 302b66ab33..159d683e9e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java @@ -17,7 +17,9 @@ package org.springframework.web.socket.client.jetty; import java.net.URI; +import java.security.Principal; import java.util.List; +import java.util.Map; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.springframework.context.SmartLifecycle; @@ -122,25 +124,26 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma } @Override - public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, String uriTemplate, Object... uriVariables) + public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, String uriTemplate, Object... uriVars) throws WebSocketConnectFailureException { - UriComponents uriComponents = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVariables).encode(); + UriComponents uriComponents = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVars).encode(); return doHandshake(webSocketHandler, null, uriComponents.toUri()); } @Override public WebSocketSession doHandshakeInternal(WebSocketHandler wsHandler, HttpHeaders headers, - URI uri, List protocols) throws WebSocketConnectFailureException { + URI uri, List protocols, Map handshakeAttributes) + throws WebSocketConnectFailureException { ClientUpgradeRequest request = new ClientUpgradeRequest(); request.setSubProtocols(protocols); - for (String header : headers.keySet()) { request.setHeader(header, headers.get(header)); } - JettyWebSocketSession wsSession = new JettyWebSocketSession(null); + Principal user = getUser(); + JettyWebSocketSession wsSession = new JettyWebSocketSession(user, handshakeAttributes); JettyWebSocketHandlerAdapter listener = new JettyWebSocketHandlerAdapter(wsHandler, wsSession); try { @@ -153,4 +156,13 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma } } + + /** + * @return the user to make available through {@link WebSocketSession#getPrincipal()}; + * by default this method returns {@code null} + */ + protected Principal getUser() { + return null; + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java index 87587e9dbd..660dac28ea 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java @@ -23,6 +23,7 @@ import java.security.NoSuchAlgorithmException; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import javax.xml.bind.DatatypeConverter; @@ -98,7 +99,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler { @Override public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, - WebSocketHandler webSocketHandler) throws IOException, HandshakeFailureException { + WebSocketHandler webSocketHandler, Map attributes) throws IOException, HandshakeFailureException { logger.debug("Starting handshake for " + request.getURI()); @@ -150,7 +151,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler { logger.trace("Upgrading with " + webSocketHandler); } - this.requestUpgradeStrategy.upgrade(request, response, selectedProtocol, webSocketHandler); + this.requestUpgradeStrategy.upgrade(request, response, selectedProtocol, webSocketHandler, attributes); return true; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java index 75654847ce..a463961fc3 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeHandler.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.server; import java.io.IOException; +import java.util.Map; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; @@ -29,7 +30,9 @@ import org.springframework.web.socket.support.PerConnectionWebSocketHandler; * @author Rossen Stoyanchev * @since 4.0 * + * @see HandshakeInterceptor * @see org.springframework.web.socket.server.support.WebSocketHttpRequestHandler + * @see org.springframework.web.socket.sockjs.SockJsService */ public interface HandshakeHandler { @@ -38,9 +41,11 @@ public interface HandshakeHandler { * * @param request the current request * @param response the current response - * @param webSocketHandler the handler to process WebSocket messages; see + * @param wsHandler the handler to process WebSocket messages; see * {@link PerConnectionWebSocketHandler} for providing a handler with * per-connection lifecycle. + * @param attributes handshake request specific attributes to be set on the WebSocket + * session and thus made available to the {@link WebSocketHandler} * * @return whether the handshake negotiation was successful or not. In either case the * response status, headers, and body will have been updated to reflect the @@ -53,7 +58,7 @@ public interface HandshakeHandler { * opposed to a failure to successfully negotiate the requirements of the * handshake request. */ - boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler webSocketHandler) - throws IOException, HandshakeFailureException; + boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, + Map attributes) throws IOException, HandshakeFailureException; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeInterceptor.java new file mode 100644 index 0000000000..86f7e0f3fb --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/HandshakeInterceptor.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server; + +import java.util.Map; + +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; + + +/** + * Interceptor for WebSocket handshake requests. Can be used to inspect the handshake + * request and response as well as to pass attributes to the target + * {@link WebSocketHandler}. + * + * @author Rossen Stoyanchev + * @since 4.0 + * + * @see org.springframework.web.socket.server.support.WebSocketHttpRequestHandler + * @see org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService + */ +public interface HandshakeInterceptor { + + /** + * Invoked before the handshake is processed. + * + * @param request the current request + * @param response the current response + * @param wsHandler the target WebSocket handler + * @param attributes attributes to make available via + * {@link WebSocketSession#getHandshakeAttributes()} + * + * @return whether to proceed with the handshake {@code true} or abort {@code false} + */ + boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Map attributes) throws Exception; + + /** + * Invoked after the handshake is done. The response status and headers indicate the + * results of the handshake, i.e. whether it was successful or not. + * + * @param request the current request + * @param response the current response + * @param wsHandler the target WebSocket handler + * @param exception an exception raised during the handshake, or {@code null} + */ + void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Exception exception); + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java index 1421f51984..3dc5d83aed 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.server; import java.io.IOException; +import java.util.Map; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; @@ -40,14 +41,18 @@ public interface RequestUpgradeStrategy { * Perform runtime specific steps to complete the upgrade. Invoked after successful * negotiation of the handshake request. * - * @param webSocketHandler the handler for WebSocket messages + * @param request the current request + * @param response the current response + * @param acceptedProtocol the accepted sub-protocol, if any + * @param wsHandler the handler for WebSocket messages + * @param attributes handshake context attributes * * @throws HandshakeFailureException thrown when handshake processing failed to * complete due to an internal, unrecoverable error, i.e. a server error as * opposed to a failure to successfully negotiate the requirements of the * handshake request. */ - void upgrade(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol, - WebSocketHandler webSocketHandler) throws IOException, HandshakeFailureException; + void upgrade(ServerHttpRequest request, ServerHttpResponse response, String acceptedProtocol, + WebSocketHandler wsHandler, Map attributes) throws IOException, HandshakeFailureException; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java index 0a02e47761..127e651169 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java @@ -18,6 +18,7 @@ package org.springframework.web.socket.server.support; import java.io.IOException; import java.net.InetSocketAddress; +import java.util.Map; import javax.websocket.Endpoint; @@ -44,14 +45,15 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS @Override - public void upgrade(ServerHttpRequest request, ServerHttpResponse response, - String acceptedProtocol, WebSocketHandler wsHandler) throws IOException, HandshakeFailureException { + public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String acceptedProtocol, + WebSocketHandler wsHandler, Map attributes) + throws IOException, HandshakeFailureException { HttpHeaders headers = request.getHeaders(); - InetSocketAddress localAddress = request.getLocalAddress(); - InetSocketAddress remoteAddress = request.getRemoteAddress(); + InetSocketAddress localAddr = request.getLocalAddress(); + InetSocketAddress remoteAddr = request.getRemoteAddress(); - StandardWebSocketSession wsSession = new StandardWebSocketSession(headers, localAddress, remoteAddress); + StandardWebSocketSession wsSession = new StandardWebSocketSession(headers, attributes, localAddr, remoteAddr); StandardWebSocketHandlerAdapter endpoint = new StandardWebSocketHandlerAdapter(wsHandler, wsSession); upgradeInternal(request, response, acceptedProtocol, endpoint); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HandshakeInterceptorChain.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HandshakeInterceptorChain.java new file mode 100644 index 0000000000..76dc2a82f5 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HandshakeInterceptorChain.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.support; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; + + +/** + * A helper class that assists with invoking a list of handshake interceptors. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class HandshakeInterceptorChain { + + private static final Log logger = LogFactory.getLog(WebSocketHttpRequestHandler.class); + + private final List interceptors; + + private final WebSocketHandler wsHandler; + + private int interceptorIndex = -1; + + + public HandshakeInterceptorChain(List interceptors, WebSocketHandler wsHandler) { + this.interceptors = (interceptors != null) ? interceptors : Collections.emptyList(); + this.wsHandler = wsHandler; + } + + + public boolean applyBeforeHandshake(ServerHttpRequest request, ServerHttpResponse response, + Map attributes) throws Exception { + + for (int i = 0; i < this.interceptors.size(); i++) { + HandshakeInterceptor interceptor = this.interceptors.get(i); + if (!interceptor.beforeHandshake(request, response, this.wsHandler, attributes)) { + applyAfterHandshake(request, response, null); + return false; + } + this.interceptorIndex = i; + } + return true; + } + + + public void applyAfterHandshake(ServerHttpRequest request, ServerHttpResponse response, Exception failure) { + for (int i = this.interceptorIndex; i >= 0; i--) { + HandshakeInterceptor interceptor = this.interceptors.get(i); + try { + interceptor.afterHandshake(request, response, this.wsHandler, failure); + } + catch (Throwable t) { + logger.warn("HandshakeInterceptor afterHandshake threw exception " + t); + } + } + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java new file mode 100644 index 0000000000..a23a7afbe0 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.support; + +import java.util.Collection; +import java.util.Enumeration; +import java.util.Map; + +import javax.servlet.http.HttpSession; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.util.CollectionUtils; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.server.HandshakeInterceptor; + + +/** + * An interceptor to copy HTTP session attributes into the map of "handshake attributes" + * made available through {@link WebSocketSession#getHandshakeAttributes()}. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class HttpSessionHandshakeInterceptor implements HandshakeInterceptor { + + private static Log logger = LogFactory.getLog(HttpSessionHandshakeInterceptor.class); + + private Collection attributeNames; + + + /** + * A constructor for copying all available HTTP session attributes. + */ + public HttpSessionHandshakeInterceptor() { + this(null); + } + + /** + * A constructor for copying a subset of HTTP session attributes. + * @param attributeNames the HTTP session attributes to copy + */ + public HttpSessionHandshakeInterceptor(Collection attributeNames) { + this.attributeNames = attributeNames; + } + + + @Override + public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Map attributes) throws Exception { + + if (request instanceof ServletServerHttpRequest) { + ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request; + HttpSession session = servletRequest.getServletRequest().getSession(false); + if (session != null) { + Enumeration names = session.getAttributeNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + if (CollectionUtils.isEmpty(this.attributeNames) || this.attributeNames.contains(name)) { + if (logger.isTraceEnabled()) { + logger.trace("Adding HTTP session attribute to handshake attributes: " + name); + } + attributes.put(name, session.getAttribute(name)); + } + else { + if (logger.isTraceEnabled()) { + logger.trace("Skipped HTTP session attribute"); + } + } + } + } + } + return true; + } + + @Override + public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Exception ex) { + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java index e6b30aa0d3..184033ca0a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.server.support; import java.io.IOException; +import java.util.Map; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -85,7 +86,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, - String protocol, WebSocketHandler wsHandler) throws IOException { + String protocol, WebSocketHandler wsHandler, Map attrs) throws IOException { Assert.isInstanceOf(ServletServerHttpRequest.class, request); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); @@ -98,7 +99,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { throw new HandshakeFailureException("Not a WebSocket request"); } - JettyWebSocketSession wsSession = new JettyWebSocketSession(request.getPrincipal()); + JettyWebSocketSession wsSession = new JettyWebSocketSession(request.getPrincipal(), attrs); JettyWebSocketHandlerAdapter wsListener = new JettyWebSocketHandlerAdapter(wsHandler, wsSession); servletRequest.setAttribute(WEBSOCKET_LISTENER_ATTR_NAME, wsListener); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java index c3caa213d8..95df66ad68 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java @@ -17,6 +17,10 @@ package org.springframework.web.socket.server.support; import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -30,7 +34,9 @@ import org.springframework.util.Assert; import org.springframework.web.HttpRequestHandler; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.DefaultHandshakeHandler; +import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.support.ExceptionWebSocketHandlerDecorator; import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator; @@ -56,6 +62,8 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler { private final WebSocketHandler webSocketHandler; + private final List interceptors = new ArrayList(); + public WebSocketHttpRequestHandler(WebSocketHandler webSocketHandler) { this(webSocketHandler, new DefaultHandshakeHandler()); @@ -69,6 +77,23 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler { } + /** + * Configure one or more WebSocket handshake request interceptors. + */ + public void setHandshakeInterceptors(List interceptors) { + this.interceptors.clear(); + if (interceptors != null) { + this.interceptors.addAll(interceptors); + } + } + + /** + * Return the configured WebSocket handshake request interceptors. + */ + public List getHandshakeInterceptors() { + return this.interceptors; + } + /** * Decorate the WebSocketHandler provided to the class constructor. * @@ -81,14 +106,36 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler { } @Override - public void handleRequest(HttpServletRequest request, HttpServletResponse response) + public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse servletResponse) throws ServletException, IOException { - ServerHttpRequest httpRequest = new ServletServerHttpRequest(request); - ServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + ServerHttpRequest request = new ServletServerHttpRequest(servletRequest); + ServerHttpResponse response = new ServletServerHttpResponse(servletResponse); - this.handshakeHandler.doHandshake(httpRequest, httpResponse, this.webSocketHandler); - httpResponse.flush(); + HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, this.webSocketHandler); + HandshakeFailureException failure = null; + + try { + Map attributes = new HashMap(); + if (!chain.applyBeforeHandshake(request, response, attributes)) { + return; + } + this.handshakeHandler.doHandshake(request, response, this.webSocketHandler, attributes); + chain.applyAfterHandshake(request, response, null); + } + catch (HandshakeFailureException ex) { + failure = ex; + } + catch (Throwable t) { + failure = new HandshakeFailureException("Uncaught failure for request " + request.getURI(), t); + } + finally { + if (failure != null) { + chain.applyAfterHandshake(request, response, failure); + throw failure; + } + response.flush(); + } } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java index 44597106f9..7a1fbc1502 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java @@ -269,8 +269,8 @@ public abstract class AbstractSockJsService implements SockJsService { * and raw WebSocket requests are delegated to abstract methods. */ @Override - public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler handler) - throws SockJsException { + public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler) throws SockJsException { String sockJsPath = getSockJsPath(request); if (sockJsPath == null) { @@ -301,7 +301,7 @@ public abstract class AbstractSockJsService implements SockJsService { this.iframeHandler.handle(request, response); } else if (sockJsPath.equals("/websocket")) { - handleRawWebSocketRequest(request, response, handler); + handleRawWebSocketRequest(request, response, wsHandler); } else { String[] pathSegments = StringUtils.tokenizeToStringArray(sockJsPath.substring(1), "/"); @@ -318,7 +318,7 @@ public abstract class AbstractSockJsService implements SockJsService { response.setStatusCode(HttpStatus.NOT_FOUND); return; } - handleTransportRequest(request, response, handler, sessionId, transport); + handleTransportRequest(request, response, wsHandler, sessionId, transport); } response.flush(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java index a64ff7c6a6..6d6033c960 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -41,7 +42,10 @@ import org.springframework.util.ObjectUtils; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.server.DefaultHandshakeHandler; +import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; +import org.springframework.web.socket.server.support.HandshakeInterceptorChain; import org.springframework.web.socket.sockjs.SockJsException; import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.support.AbstractSockJsService; @@ -75,6 +79,8 @@ public class DefaultSockJsService extends AbstractSockJsService { private SockJsMessageCodec messageCodec; + private final List interceptors = new ArrayList(); + private final Map sessions = new ConcurrentHashMap(); private ScheduledFuture sessionCleanupTask; @@ -167,6 +173,23 @@ public class DefaultSockJsService extends AbstractSockJsService { } + /** + * Configure one or more WebSocket handshake request interceptors. + */ + public void setHandshakeInterceptors(List interceptors) { + this.interceptors.clear(); + if (interceptors != null) { + this.interceptors.addAll(interceptors); + } + } + + /** + * Return the configured WebSocket handshake request interceptors. + */ + public List getHandshakeInterceptors() { + return this.interceptors; + } + /** * The codec to use for encoding and decoding SockJS messages. * @exception IllegalStateException if no {@link SockJsMessageCodec} is available @@ -185,19 +208,42 @@ public class DefaultSockJsService extends AbstractSockJsService { @Override protected void handleRawWebSocketRequest(ServerHttpRequest request, ServerHttpResponse response, - WebSocketHandler webSocketHandler) throws IOException { + WebSocketHandler wsHandler) throws IOException { - if (isWebSocketEnabled()) { - TransportHandler transportHandler = this.transportHandlers.get(TransportType.WEBSOCKET); - if (transportHandler != null) { - if (transportHandler instanceof HandshakeHandler) { - ((HandshakeHandler) transportHandler).doHandshake(request, response, webSocketHandler); - return; - } - } - logger.warn("No handler for raw WebSocket messages"); + if (!isWebSocketEnabled()) { + return; + } + + TransportHandler transportHandler = this.transportHandlers.get(TransportType.WEBSOCKET); + if ((transportHandler == null) || !(transportHandler instanceof HandshakeHandler)) { + logger.warn("No handler for raw WebSocket messages"); + response.setStatusCode(HttpStatus.NOT_FOUND); + return; + } + + HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, wsHandler); + HandshakeFailureException failure = null; + + try { + Map attributes = new HashMap(); + if (!chain.applyBeforeHandshake(request, response, attributes)) { + return; + } + ((HandshakeHandler) transportHandler).doHandshake(request, response, wsHandler, attributes); + chain.applyAfterHandshake(request, response, null); + } + catch (HandshakeFailureException ex) { + failure = ex; + } + catch (Throwable t) { + failure = new HandshakeFailureException("Uncaught failure for request " + request.getURI(), t); + } + finally { + if (failure != null) { + chain.applyAfterHandshake(request, response, failure); + throw failure; + } } - response.setStatusCode(HttpStatus.NOT_FOUND); } @Override @@ -235,38 +281,61 @@ public class DefaultSockJsService extends AbstractSockJsService { return; } - WebSocketSession session = this.sessions.get(sessionId); - if (session == null) { - if (transportHandler instanceof SockJsSessionFactory) { - SockJsSessionFactory sessionFactory = (SockJsSessionFactory) transportHandler; - session = createSockJsSession(sessionId, sessionFactory, wsHandler, request, response); + HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, wsHandler); + SockJsException failure = null; + + try { + WebSocketSession session = this.sessions.get(sessionId); + if (session == null) { + if (transportHandler instanceof SockJsSessionFactory) { + Map attributes = new HashMap(); + if (!chain.applyBeforeHandshake(request, response, attributes)) { + return; + } + SockJsSessionFactory sessionFactory = (SockJsSessionFactory) transportHandler; + session = createSockJsSession(sessionId, sessionFactory, wsHandler, attributes, request, response); + } + else { + response.setStatusCode(HttpStatus.NOT_FOUND); + logger.warn("Session not found"); + return; + } + } + + if (transportType.sendsNoCacheInstruction()) { + addNoCacheHeaders(response); + } + + if (transportType.sendsSessionCookie() && isDummySessionCookieEnabled()) { + Cookie cookie = request.getCookies().get("JSESSIONID"); + String value = (cookie != null) ? cookie.getValue() : "dummy"; + response.getHeaders().set("Set-Cookie", "JSESSIONID=" + value + ";path=/"); + } + + if (transportType.supportsCors()) { + addCorsHeaders(request, response); + } + + transportHandler.handleRequest(request, response, wsHandler, session); + chain.applyAfterHandshake(request, response, null); + } + catch (SockJsException ex) { + failure = ex; + } + catch (Throwable t) { + failure = new SockJsException("Uncaught failure for request " + request.getURI(), sessionId, t); + } + finally { + if (failure != null) { + chain.applyAfterHandshake(request, response, failure); + throw failure; } } - if (session == null) { - response.setStatusCode(HttpStatus.NOT_FOUND); - logger.warn("Session not found"); - return; - } - - if (transportType.sendsNoCacheInstruction()) { - addNoCacheHeaders(response); - } - - if (transportType.sendsSessionCookie() && isDummySessionCookieEnabled()) { - Cookie cookie = request.getCookies().get("JSESSIONID"); - String value = (cookie != null) ? cookie.getValue() : "dummy"; - response.getHeaders().set("Set-Cookie", "JSESSIONID=" + value + ";path=/"); - } - - if (transportType.supportsCors()) { - addCorsHeaders(request, response); - } - - transportHandler.handleRequest(request, response, wsHandler, session); } private WebSocketSession createSockJsSession(String sessionId, SockJsSessionFactory sessionFactory, - WebSocketHandler handler, ServerHttpRequest request, ServerHttpResponse response) { + WebSocketHandler wsHandler, Map handshakeAttributes, + ServerHttpRequest request, ServerHttpResponse response) { synchronized (this.sessions) { AbstractSockJsSession session = this.sessions.get(sessionId); @@ -276,8 +345,9 @@ public class DefaultSockJsService extends AbstractSockJsService { if (this.sessionCleanupTask == null) { scheduleSessionTask(); } + logger.debug("Creating new session with session id \"" + sessionId + "\""); - session = sessionFactory.createSession(sessionId, handler); + session = sessionFactory.createSession(sessionId, wsHandler, handshakeAttributes); this.sessions.put(sessionId, session); return session; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java index 082a8e181a..f13b3f8ba4 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java @@ -18,6 +18,7 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.io.IOException; import java.nio.charset.Charset; +import java.util.Map; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; @@ -49,8 +50,10 @@ public class EventSourceTransportHandler extends AbstractHttpSendingTransportHan } @Override - public StreamingSockJsSession createSession(String sessionId, WebSocketHandler handler) { - return new EventSourceStreamingSockJsSession(sessionId, getSockJsServiceConfig(), handler); + public StreamingSockJsSession createSession(String sessionId, WebSocketHandler wsHandler, + Map attributes) { + + return new EventSourceStreamingSockJsSession(sessionId, getSockJsServiceConfig(), wsHandler, attributes); } @Override @@ -61,8 +64,10 @@ public class EventSourceTransportHandler extends AbstractHttpSendingTransportHan private final class EventSourceStreamingSockJsSession extends StreamingSockJsSession { - private EventSourceStreamingSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + private EventSourceStreamingSockJsSession(String sessionId, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(sessionId, config, wsHandler, attributes); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java index e09df75994..980afa4fb1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java @@ -18,6 +18,7 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.io.IOException; import java.nio.charset.Charset; +import java.util.Map; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; @@ -87,8 +88,10 @@ public class HtmlFileTransportHandler extends AbstractHttpSendingTransportHandle } @Override - public StreamingSockJsSession createSession(String sessionId, WebSocketHandler handler) { - return new HtmlFileStreamingSockJsSession(sessionId, getSockJsServiceConfig(), handler); + public StreamingSockJsSession createSession(String sessionId, WebSocketHandler wsHandler, + Map attributes) { + + return new HtmlFileStreamingSockJsSession(sessionId, getSockJsServiceConfig(), wsHandler, attributes); } @Override @@ -124,8 +127,10 @@ public class HtmlFileTransportHandler extends AbstractHttpSendingTransportHandle private final class HtmlFileStreamingSockJsSession extends StreamingSockJsSession { - private HtmlFileStreamingSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + private HtmlFileStreamingSockJsSession(String sessionId, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(sessionId, config, wsHandler, attributes); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/JsonpPollingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/JsonpPollingTransportHandler.java index 0d8cea2dde..5c501646ef 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/JsonpPollingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/JsonpPollingTransportHandler.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.nio.charset.Charset; +import java.util.Map; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; @@ -53,8 +54,10 @@ public class JsonpPollingTransportHandler extends AbstractHttpSendingTransportHa } @Override - public PollingSockJsSession createSession(String sessionId, WebSocketHandler handler) { - return new PollingSockJsSession(sessionId, getSockJsServiceConfig(), handler); + public PollingSockJsSession createSession(String sessionId, WebSocketHandler wsHandler, + Map attributes) { + + return new PollingSockJsSession(sessionId, getSockJsServiceConfig(), wsHandler, attributes); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsSessionFactory.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsSessionFactory.java index 06d54ae09e..ae15c84788 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsSessionFactory.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsSessionFactory.java @@ -16,6 +16,8 @@ package org.springframework.web.socket.sockjs.transport.handler; +import java.util.Map; + import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.transport.TransportHandler; import org.springframework.web.socket.sockjs.transport.session.AbstractSockJsSession; @@ -31,10 +33,13 @@ public interface SockJsSessionFactory { /** * Create a new SockJS session. + * * @param sessionId the ID of the session - * @param webSocketHandler the underlying {@link WebSocketHandler} - * @return a new non-null session + * @param wsHandler the underlying {@link WebSocketHandler} + * @param attributes handshake request specific attributes + * + * @return a new session, never {@code null} */ - AbstractSockJsSession createSession(String sessionId, WebSocketHandler webSocketHandler); + AbstractSockJsSession createSession(String sessionId, WebSocketHandler wsHandler, Map attributes); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java index f6d40b33b3..788cb93bd6 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java @@ -17,6 +17,8 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.io.IOException; +import java.util.Collections; +import java.util.Map; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; @@ -60,8 +62,10 @@ public class WebSocketTransportHandler extends TransportHandlerSupport } @Override - public AbstractSockJsSession createSession(String sessionId, WebSocketHandler webSocketHandler) { - return new WebSocketServerSockJsSession(sessionId, getSockJsServiceConfig(), webSocketHandler); + public AbstractSockJsSession createSession(String sessionId, WebSocketHandler wsHandler, + Map attributes) { + + return new WebSocketServerSockJsSession(sessionId, getSockJsServiceConfig(), wsHandler, attributes); } @Override @@ -71,7 +75,7 @@ public class WebSocketTransportHandler extends TransportHandlerSupport WebSocketServerSockJsSession sockJsSession = (WebSocketServerSockJsSession) wsSession; try { wsHandler = new SockJsWebSocketHandler(getSockJsServiceConfig(), wsHandler, sockJsSession); - this.handshakeHandler.doHandshake(request, response, wsHandler); + this.handshakeHandler.doHandshake(request, response, wsHandler, Collections.emptyMap()); } catch (Throwable t) { sockJsSession.tryCloseWithSockJsTransportError(t, CloseStatus.SERVER_ERROR); @@ -82,10 +86,10 @@ public class WebSocketTransportHandler extends TransportHandlerSupport // HandshakeHandler methods @Override - public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler handler) - throws IOException { + public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler handler, Map attributes) throws IOException { - return this.handshakeHandler.doHandshake(request, response, handler); + return this.handshakeHandler.doHandshake(request, response, handler, attributes); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrPollingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrPollingTransportHandler.java index ed59a339e2..832525dfa2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrPollingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrPollingTransportHandler.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.nio.charset.Charset; +import java.util.Map; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; @@ -51,8 +52,10 @@ public class XhrPollingTransportHandler extends AbstractHttpSendingTransportHand } @Override - public PollingSockJsSession createSession(String sessionId, WebSocketHandler handler) { - return new PollingSockJsSession(sessionId, getSockJsServiceConfig(), handler); + public PollingSockJsSession createSession(String sessionId, WebSocketHandler wsHandler, + Map attributes) { + + return new PollingSockJsSession(sessionId, getSockJsServiceConfig(), wsHandler, attributes); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java index 307471c705..03da01986f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java @@ -18,6 +18,7 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.io.IOException; import java.nio.charset.Charset; +import java.util.Map; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; @@ -49,8 +50,10 @@ public class XhrStreamingTransportHandler extends AbstractHttpSendingTransportHa } @Override - public StreamingSockJsSession createSession(String sessionId, WebSocketHandler handler) { - return new XhrStreamingSockJsSession(sessionId, getSockJsServiceConfig(), handler); + public StreamingSockJsSession createSession(String sessionId, WebSocketHandler wsHandler, + Map attributes) { + + return new XhrStreamingSockJsSession(sessionId, getSockJsServiceConfig(), wsHandler, attributes); } @Override @@ -61,8 +64,10 @@ public class XhrStreamingTransportHandler extends AbstractHttpSendingTransportHa private final class XhrStreamingSockJsSession extends StreamingSockJsSession { - private XhrStreamingSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + private XhrStreamingSockJsSession(String sessionId, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(sessionId, config, wsHandler, attributes); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java index 8f85b4ca2e..f3a0c9d347 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java @@ -18,7 +18,9 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.IOException; import java.net.InetSocketAddress; +import java.net.URI; import java.security.Principal; +import java.util.Map; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; @@ -52,7 +54,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { private ServerHttpAsyncRequestControl asyncRequestControl; - private String protocol; + private URI uri; private HttpHeaders handshakeHeaders; @@ -62,12 +64,21 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { private InetSocketAddress remoteAddress; + private String acceptedProtocol; - public AbstractHttpSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler) { - super(id, config, wsHandler); + + public AbstractHttpSockJsSession(String id, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map handshakeAttributes) { + + super(id, config, wsHandler, handshakeAttributes); } + @Override + public URI getUri() { + return this.uri; + } + @Override public HttpHeaders getHandshakeHeaders() { return this.handshakeHeaders; @@ -112,14 +123,14 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { * @param protocol the sub-protocol to set */ public void setAcceptedProtocol(String protocol) { - this.protocol = protocol; + this.acceptedProtocol = protocol; } /** * Return the selected sub-protocol to use. */ public String getAcceptedProtocol() { - return this.protocol; + return this.acceptedProtocol; } public synchronized void setInitialRequest(ServerHttpRequest request, ServerHttpResponse response, @@ -135,6 +146,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { throw new SockJsTransportFailureException("Failed to send \"open\" frame", getId(), t); } + this.uri = request.getURI(); this.handshakeHeaders = request.getHeaders(); this.principal = request.getPrincipal(); this.localAddress = request.getLocalAddress(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java index 0a94cce68c..1fbd5e1de3 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java @@ -19,12 +19,11 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.EOFException; import java.io.IOException; import java.net.SocketException; -import java.net.URI; -import java.security.Principal; import java.util.ArrayList; import java.util.Arrays; import java.util.Date; import java.util.List; +import java.util.Map; import java.util.concurrent.ScheduledFuture; import org.apache.commons.logging.Log; @@ -51,18 +50,12 @@ public abstract class AbstractSockJsSession implements WebSocketSession { private final String id; - private URI uri; - - private String remoteHostName; - - private String remoteAddress; - - private Principal principal; - - private final SockJsServiceConfig sockJsServiceConfig; + private final SockJsServiceConfig config; private final WebSocketHandler handler; + private final Map handshakeAttributes; + private State state = State.NEW; private final long timeCreated = System.currentTimeMillis(); @@ -73,17 +66,21 @@ public abstract class AbstractSockJsSession implements WebSocketSession { /** - * @param sessionId the session ID + * @param id the session ID * @param config SockJS service configuration options * @param wsHandler the recipient of SockJS messages */ - public AbstractSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler wsHandler) { - Assert.notNull(sessionId, "sessionId is required"); + public AbstractSockJsSession(String id, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map handshakeAttributes) { + + Assert.notNull(id, "sessionId is required"); Assert.notNull(config, "sockJsConfig is required"); Assert.notNull(wsHandler, "webSocketHandler is required"); - this.id = sessionId; - this.sockJsServiceConfig = config; + + this.id = id; + this.config = config; this.handler = wsHandler; + this.handshakeAttributes = handshakeAttributes; } @Override @@ -91,13 +88,13 @@ public abstract class AbstractSockJsSession implements WebSocketSession { return this.id; } - @Override - public URI getUri() { - return this.uri; + public SockJsServiceConfig getSockJsServiceConfig() { + return this.config; } - public SockJsServiceConfig getSockJsServiceConfig() { - return this.sockJsServiceConfig; + @Override + public Map getHandshakeAttributes() { + return this.handshakeAttributes; } public boolean isNew() { @@ -306,13 +303,13 @@ public abstract class AbstractSockJsSession implements WebSocketSession { } protected void scheduleHeartbeat() { - Assert.state(this.sockJsServiceConfig.getTaskScheduler() != null, "heartbeatScheduler not configured"); + Assert.state(this.config.getTaskScheduler() != null, "heartbeatScheduler not configured"); cancelHeartbeat(); if (!isActive()) { return; } - Date time = new Date(System.currentTimeMillis() + this.sockJsServiceConfig.getHeartbeatTime()); - this.heartbeatTask = this.sockJsServiceConfig.getTaskScheduler().schedule(new Runnable() { + Date time = new Date(System.currentTimeMillis() + this.config.getHeartbeatTime()); + this.heartbeatTask = this.config.getTaskScheduler().schedule(new Runnable() { public void run() { try { sendHeartbeat(); @@ -323,7 +320,7 @@ public abstract class AbstractSockJsSession implements WebSocketSession { } }, time); if (logger.isTraceEnabled()) { - logger.trace("Scheduled heartbeat after " + this.sockJsServiceConfig.getHeartbeatTime()/1000 + " seconds"); + logger.trace("Scheduled heartbeat after " + this.config.getHeartbeatTime()/1000 + " seconds"); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java index 9fe6a5b8c7..1f1e4668f5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java @@ -16,6 +16,8 @@ package org.springframework.web.socket.sockjs.transport.session; +import java.util.Map; + import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.SockJsTransportFailureException; import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; @@ -28,8 +30,11 @@ import org.springframework.web.socket.sockjs.support.frame.SockJsMessageCodec; */ public class PollingSockJsSession extends AbstractHttpSockJsSession { - public PollingSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + + public PollingSockJsSession(String sessionId, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(sessionId, config, wsHandler, attributes); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java index 7ae5770e44..017bc67f2c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.IOException; +import java.util.Map; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; @@ -37,8 +38,10 @@ public class StreamingSockJsSession extends AbstractHttpSockJsSession { private int byteCount; - public StreamingSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + public StreamingSockJsSession(String sessionId, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(sessionId, config, wsHandler, attributes); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java index 3f93b23cbc..ac3cbfcf6a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java @@ -18,7 +18,9 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.IOException; import java.net.InetSocketAddress; +import java.net.URI; import java.security.Principal; +import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.util.Assert; @@ -44,8 +46,17 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession private WebSocketSession wsSession; - public WebSocketServerSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler) { - super(id, config, wsHandler); + public WebSocketServerSockJsSession(String id, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(id, config, wsHandler, attributes); + } + + + @Override + public URI getUri() { + checkDelegateSessionInitialized(); + return this.wsSession.getUri(); } @Override diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java index a964543b15..cd797a92aa 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java @@ -44,7 +44,7 @@ public class JettyWebSocketHandlerAdapterTests { public void setup() { this.session = mock(Session.class); this.webSocketHandler = mock(WebSocketHandler.class); - this.webSocketSession = new JettyWebSocketSession(null); + this.webSocketSession = new JettyWebSocketSession(null, null); this.adapter = new JettyWebSocketHandlerAdapter(this.webSocketHandler, this.webSocketSession); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java index ede2a5e951..6a3b071adf 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java @@ -50,7 +50,7 @@ public class StandardWebSocketHandlerAdapterTests { public void setup() { this.session = mock(Session.class); this.webSocketHandler = mock(WebSocketHandler.class); - this.webSocketSession = new StandardWebSocketSession(null, null, null); + this.webSocketSession = new StandardWebSocketSession(null, null, null, null); this.adapter = new StandardWebSocketHandlerAdapter(this.webSocketHandler, this.webSocketSession); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java index 2eaa55054e..a761fd9c54 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java @@ -113,7 +113,7 @@ public class JettyWebSocketClientTests { resp.setAcceptedSubProtocol(req.getSubProtocols().get(0)); } - JettyWebSocketSession session = new JettyWebSocketSession(null); + JettyWebSocketSession session = new JettyWebSocketSession(null, null); return new JettyWebSocketHandlerAdapter(webSocketHandler, session); } }); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java index c47ed6906d..537b8a524d 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java @@ -16,6 +16,9 @@ package org.springframework.web.socket.server; +import java.util.Collections; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -62,10 +65,11 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { this.request.getHeaders().setSecWebSocketProtocol("STOMP"); WebSocketHandler handler = new TextWebSocketHandlerAdapter(); + Map attributes = Collections.emptyMap(); - this.handshakeHandler.doHandshake(this.request, this.response, handler); + this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes); - verify(this.upgradeStrategy).upgrade(request, response, "STOMP", handler); + verify(this.upgradeStrategy).upgrade(this.request, this.response, "STOMP", handler, attributes); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HandshakeInterceptorChainTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HandshakeInterceptorChainTests.java new file mode 100644 index 0000000000..e0f28e5f92 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HandshakeInterceptorChainTests.java @@ -0,0 +1,101 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.support; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.web.socket.AbstractHttpRequestTests; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; + +import static org.mockito.Mockito.*; + + +/** + * Test fixture for {@link HandshakeInterceptorChain}. + * + * @author Rossen Stoyanchev + */ +public class HandshakeInterceptorChainTests extends AbstractHttpRequestTests { + + private HandshakeInterceptor i1; + + private HandshakeInterceptor i2; + + private HandshakeInterceptor i3; + + private List interceptors; + + private WebSocketHandler wsHandler; + + private Map attributes; + + + @Before + public void setup() { + i1 = mock(HandshakeInterceptor.class); + i2 = mock(HandshakeInterceptor.class); + i3 = mock(HandshakeInterceptor.class); + interceptors = Arrays.asList(i1, i2, i3); + wsHandler = mock(WebSocketHandler.class); + attributes = new HashMap(); + } + + + @Test + public void success() throws Exception { + when(i1.beforeHandshake(request, response, wsHandler, attributes)).thenReturn(true); + when(i2.beforeHandshake(request, response, wsHandler, attributes)).thenReturn(true); + when(i3.beforeHandshake(request, response, wsHandler, attributes)).thenReturn(true); + + HandshakeInterceptorChain chain = new HandshakeInterceptorChain(interceptors, wsHandler); + chain.applyBeforeHandshake(request, response, attributes); + + verify(i1).beforeHandshake(request, response, wsHandler, attributes); + verify(i2).beforeHandshake(request, response, wsHandler, attributes); + verify(i3).beforeHandshake(request, response, wsHandler, attributes); + verifyNoMoreInteractions(i1, i2, i3); + } + + @Test + public void applyBeforeHandshakeWithFalseReturnValue() throws Exception { + when(i1.beforeHandshake(request, response, wsHandler, attributes)).thenReturn(true); + when(i2.beforeHandshake(request, response, wsHandler, attributes)).thenReturn(false); + + HandshakeInterceptorChain chain = new HandshakeInterceptorChain(interceptors, wsHandler); + chain.applyBeforeHandshake(request, response, attributes); + + verify(i1).beforeHandshake(request, response, wsHandler, attributes); + verify(i1).afterHandshake(request, response, wsHandler, null); + verify(i2).beforeHandshake(request, response, wsHandler, attributes); + verifyNoMoreInteractions(i1, i2, i3); + } + + @Test + public void applyAfterHandshakeOnly() { + HandshakeInterceptorChain chain = new HandshakeInterceptorChain(interceptors, wsHandler); + chain.applyAfterHandshake(request, response, null); + + verifyNoMoreInteractions(i1, i2, i3); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java new file mode 100644 index 0000000000..bc51fd5d77 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.support; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import org.junit.Test; +import org.mockito.Mockito; +import org.springframework.web.socket.AbstractHttpRequestTests; +import org.springframework.web.socket.WebSocketHandler; + +import static org.junit.Assert.*; + + +/** + * Test fixture for {@link HttpSessionHandshakeInterceptor}. + * + * @author Rossen Stoyanchev + */ +public class HttpSessionHandshakeInterceptorTests extends AbstractHttpRequestTests { + + + @Test + public void copyAllAttributes() throws Exception { + + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + + this.servletRequest.getSession().setAttribute("foo", "bar"); + this.servletRequest.getSession().setAttribute("bar", "baz"); + + HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); + interceptor.beforeHandshake(request, response, wsHandler, attributes); + + assertEquals(2, attributes.size()); + assertEquals("bar", attributes.get("foo")); + assertEquals("baz", attributes.get("bar")); + } + + @Test + public void copySelectedAttributes() throws Exception { + + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + + this.servletRequest.getSession().setAttribute("foo", "bar"); + this.servletRequest.getSession().setAttribute("bar", "baz"); + + Set names = Collections.singleton("foo"); + HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(names); + interceptor.beforeHandshake(request, response, wsHandler, attributes); + + assertEquals(1, attributes.size()); + assertEquals("bar", attributes.get("foo")); + } + + @Test + public void doNotCauseSessionCreation() throws Exception { + + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + + HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); + interceptor.beforeHandshake(request, response, wsHandler, attributes); + + assertNull(this.servletRequest.getSession(false)); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java index 85c3540e39..ba1229a92b 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.sockjs.transport.handler; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -70,10 +71,11 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { MockitoAnnotations.initMocks(this); - this.session = new TestSockJsSession(sessionId, new StubSockJsServiceConfig(), this.wsHandler); + Map attributes = Collections.emptyMap(); + this.session = new TestSockJsSession(sessionId, new StubSockJsServiceConfig(), this.wsHandler, attributes); when(this.xhrHandler.getTransportType()).thenReturn(TransportType.XHR); - when(this.xhrHandler.createSession(sessionId, this.wsHandler)).thenReturn(this.session); + when(this.xhrHandler.createSession(sessionId, this.wsHandler, attributes)).thenReturn(this.session); when(this.xhrSendHandler.getTransportType()).thenReturn(TransportType.XHR_SEND); this.service = new DefaultSockJsService(this.taskScheduler, diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpReceivingTransportHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpReceivingTransportHandlerTests.java index 0f3f54ce2e..ef8396abd1 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpReceivingTransportHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpReceivingTransportHandlerTests.java @@ -107,7 +107,7 @@ public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTest this.servletRequest.setContent("[\"x\"]".getBytes("UTF-8")); WebSocketHandler wsHandler = mock(WebSocketHandler.class); - TestHttpSockJsSession session = new TestHttpSockJsSession("1", sockJsConfig, wsHandler); + TestHttpSockJsSession session = new TestHttpSockJsSession("1", sockJsConfig, wsHandler, null); session.delegateConnectionEstablished(); doThrow(new Exception()).when(wsHandler).handleMessage(session, new TextMessage("x")); @@ -127,7 +127,7 @@ public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTest private void handleRequest(AbstractHttpReceivingTransportHandler transportHandler) throws Exception { WebSocketHandler wsHandler = mock(WebSocketHandler.class); - AbstractSockJsSession session = new TestHttpSockJsSession("1", new StubSockJsServiceConfig(), wsHandler); + AbstractSockJsSession session = new TestHttpSockJsSession("1", new StubSockJsServiceConfig(), wsHandler, null); transportHandler.setSockJsServiceConfiguration(new StubSockJsServiceConfig()); transportHandler.handleRequest(this.request, this.response, wsHandler, session); @@ -141,7 +141,7 @@ public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTest resetResponse(); WebSocketHandler wsHandler = mock(WebSocketHandler.class); - AbstractSockJsSession session = new TestHttpSockJsSession("1", new StubSockJsServiceConfig(), wsHandler); + AbstractSockJsSession session = new TestHttpSockJsSession("1", new StubSockJsServiceConfig(), wsHandler, null); new XhrReceivingTransportHandler().handleRequest(this.request, this.response, wsHandler, session); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java index b58937b540..7e7c9f1b18 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java @@ -66,7 +66,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests XhrPollingTransportHandler transportHandler = new XhrPollingTransportHandler(); transportHandler.setSockJsServiceConfiguration(this.sockJsConfig); - AbstractSockJsSession session = transportHandler.createSession("1", this.webSocketHandler); + AbstractSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); transportHandler.handleRequest(this.request, this.response, this.webSocketHandler, session); assertEquals("application/javascript;charset=UTF-8", this.response.getHeaders().getContentType().toString()); @@ -92,7 +92,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests JsonpPollingTransportHandler transportHandler = new JsonpPollingTransportHandler(); transportHandler.setSockJsServiceConfiguration(this.sockJsConfig); - PollingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler); + PollingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); transportHandler.handleRequest(this.request, this.response, this.webSocketHandler, session); @@ -114,7 +114,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests XhrStreamingTransportHandler transportHandler = new XhrStreamingTransportHandler(); transportHandler.setSockJsServiceConfiguration(this.sockJsConfig); - AbstractSockJsSession session = transportHandler.createSession("1", this.webSocketHandler); + AbstractSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); transportHandler.handleRequest(this.request, this.response, this.webSocketHandler, session); @@ -128,7 +128,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests HtmlFileTransportHandler transportHandler = new HtmlFileTransportHandler(); transportHandler.setSockJsServiceConfiguration(this.sockJsConfig); - StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler); + StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); transportHandler.handleRequest(this.request, this.response, this.webSocketHandler, session); @@ -150,7 +150,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests EventSourceTransportHandler transportHandler = new EventSourceTransportHandler(); transportHandler.setSockJsServiceConfiguration(this.sockJsConfig); - StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler); + StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); transportHandler.handleRequest(this.request, this.response, this.webSocketHandler, session); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSessionTests.java index 74abb4fb31..ed8c9b28a4 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSessionTests.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.IOException; +import java.util.Map; import org.junit.Before; import org.junit.Test; @@ -71,7 +72,7 @@ public class AbstractHttpSockJsSessionTests extends BaseAbstractSockJsSessionTes @Override protected TestAbstractHttpSockJsSession initSockJsSession() { - return new TestAbstractHttpSockJsSession(this.sockJsConfig, this.webSocketHandler); + return new TestAbstractHttpSockJsSession(this.sockJsConfig, this.webSocketHandler, null); } @Test @@ -126,8 +127,10 @@ public class AbstractHttpSockJsSessionTests extends BaseAbstractSockJsSessionTes private boolean heartbeatScheduled; - public TestAbstractHttpSockJsSession(SockJsServiceConfig config, WebSocketHandler handler) { - super("1", config, handler); + public TestAbstractHttpSockJsSession(SockJsServiceConfig config, WebSocketHandler handler, + Map attributes) { + + super("1", config, handler, attributes); } public boolean wasCacheFlushed() { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSessionTests.java index d993960d39..f4a76713bb 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSessionTests.java @@ -45,7 +45,8 @@ public class AbstractSockJsSessionTests extends BaseAbstractSockJsSessionTestsemptyMap()); } @Test @@ -102,7 +103,8 @@ public class AbstractSockJsSessionTests extends BaseAbstractSockJsSessionTestsemptyMap()); String msg1 = "message 1"; String msg2 = "message 2"; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestHttpSockJsSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestHttpSockJsSession.java index ac4fefeb29..6269883e2b 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestHttpSockJsSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestHttpSockJsSession.java @@ -19,6 +19,7 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.WebSocketHandler; @@ -45,8 +46,10 @@ public class TestHttpSockJsSession extends AbstractHttpSockJsSession { private String subProtocol; - public TestHttpSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + public TestHttpSockJsSession(String sessionId, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(sessionId, config, wsHandler, attributes); } @Override diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java index 4aad2c6336..1cee547aab 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java @@ -18,9 +18,11 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.IOException; import java.net.InetSocketAddress; +import java.net.URI; import java.security.Principal; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.web.socket.CloseStatus; @@ -32,6 +34,8 @@ import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; */ public class TestSockJsSession extends AbstractSockJsSession { + private URI uri; + private HttpHeaders headers; private Principal principal; @@ -55,11 +59,22 @@ public class TestSockJsSession extends AbstractSockJsSession { private String subProtocol; - public TestSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + public TestSockJsSession(String sessionId, SockJsServiceConfig config, + WebSocketHandler wsHandler, Map attributes) { + + super(sessionId, config, wsHandler, attributes); } + public void setUri(URI uri) { + this.uri = uri; + } + + @Override + public URI getUri() { + return this.uri; + } + @Override public HttpHeaders getHandshakeHeaders() { return this.headers; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSessionTests.java index 003f2503c0..f6ff6e5043 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSessionTests.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import org.junit.Before; import org.junit.Test; @@ -53,7 +54,8 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession @Override protected TestWebSocketServerSockJsSession initSockJsSession() { - return new TestWebSocketServerSockJsSession(this.sockJsConfig, this.webSocketHandler); + return new TestWebSocketServerSockJsSession(this.sockJsConfig, this.webSocketHandler, + Collections.emptyMap()); } @Test @@ -132,8 +134,10 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession private final List heartbeatSchedulingEvents = new ArrayList<>(); - public TestWebSocketServerSockJsSession(SockJsServiceConfig config, WebSocketHandler handler) { - super("1", config, handler); + public TestWebSocketServerSockJsSession(SockJsServiceConfig config, WebSocketHandler handler, + Map attributes) { + + super("1", config, handler, attributes); } @Override diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java index 53d78aa1ea..6f16888e68 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java @@ -21,7 +21,9 @@ import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.web.socket.CloseStatus; @@ -39,6 +41,8 @@ public class TestWebSocketSession implements WebSocketSession { private URI uri; + private Map attributes = new HashMap(); + private Principal principal; private InetSocketAddress localAddress; @@ -106,6 +110,21 @@ public class TestWebSocketSession implements WebSocketSession { this.headers = headers; } + /** + * @param attributes the attributes to set + */ + public void setHandshakeAttributes(Map attributes) { + this.attributes = attributes; + } + + /** + * @return the attributes + */ + @Override + public Map getHandshakeAttributes() { + return this.attributes; + } + /** * @return the principal */