From 4e82416ba97abc804bd805dc4e9e10f5b3c685e7 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Mon, 25 Nov 2013 18:02:35 -0500 Subject: [PATCH] Add SubProtocolCapable interface The addition of SubProtocolCapable simplifies configuration since it is no longer necessary to explicitly configure DefaultHandshakeHandler with a list of supported sub-protocols. We will not also check if the WebSocketHandler to use for the WebSocket request is an instance of SubProtocolCapable and obtain the list of sub-protocols that way. The provided SubProtocolWebSocketHandler does implement this interface. Issue: SPR-11111 --- .../SubProtocolWebSocketHandler.java | 14 ++-- .../config/WebMvcStompEndpointRegistry.java | 3 +- ...MvcStompWebSocketEndpointRegistration.java | 30 +++----- .../server/DefaultHandshakeHandler.java | 69 +++++++++++++++--- .../socket/support/SubProtocolCapable.java | 19 +++++ .../WebMvcStompEndpointRegistrationTests.java | 7 +- .../server/DefaultHandshakeHandlerTests.java | 71 +++++++++++++++++-- 7 files changed, 161 insertions(+), 52 deletions(-) create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/support/SubProtocolCapable.java diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java index 1bdca79d5a..bd521bbbe5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java @@ -16,12 +16,7 @@ package org.springframework.web.socket.messaging; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.TreeMap; +import java.util.*; import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; @@ -37,6 +32,7 @@ import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.support.SubProtocolCapable; /** @@ -55,7 +51,7 @@ import org.springframework.web.socket.WebSocketSession; * * @since 4.0 */ -public class SubProtocolWebSocketHandler implements WebSocketHandler, MessageHandler { +public class SubProtocolWebSocketHandler implements SubProtocolCapable, WebSocketHandler, MessageHandler { private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class); @@ -136,8 +132,8 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, MessageHan /** * Return all supported protocols. */ - public Set getSupportedProtocols() { - return this.protocolHandlers.keySet(); + public List getSubProtocols() { + return new ArrayList(this.protocolHandlers.keySet()); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistry.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistry.java index 150d6696b6..9144a67978 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistry.java @@ -83,10 +83,9 @@ public class WebMvcStompEndpointRegistry implements StompEndpointRegistry { public StompWebSocketEndpointRegistration addEndpoint(String... paths) { this.subProtocolWebSocketHandler.addProtocolHandler(this.stompHandler); - Set subProtocols = this.subProtocolWebSocketHandler.getSupportedProtocols(); WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration( - paths, this.webSocketHandler, subProtocols, this.sockJsScheduler); + paths, this.webSocketHandler, this.sockJsScheduler); this.registrations.add(registration); return registration; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebMvcStompWebSocketEndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebMvcStompWebSocketEndpointRegistration.java index 8b5cf69496..6f0c04114c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebMvcStompWebSocketEndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/config/WebMvcStompWebSocketEndpointRegistration.java @@ -46,8 +46,6 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE private final WebSocketHandler webSocketHandler; - private final String[] subProtocols; - private final TaskScheduler sockJsTaskScheduler; private HandshakeHandler handshakeHandler; @@ -56,28 +54,14 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE public WebMvcStompWebSocketEndpointRegistration(String[] paths, WebSocketHandler webSocketHandler, - Set subProtocols, TaskScheduler sockJsTaskScheduler) { + TaskScheduler sockJsTaskScheduler) { Assert.notEmpty(paths, "No paths specified"); Assert.notNull(webSocketHandler, "'webSocketHandler' is required"); - Assert.notNull(subProtocols, "'subProtocols' is required"); this.paths = paths; this.webSocketHandler = webSocketHandler; - this.subProtocols = subProtocols.toArray(new String[subProtocols.size()]); this.sockJsTaskScheduler = sockJsTaskScheduler; - - this.handshakeHandler = new DefaultHandshakeHandler(); - updateHandshakeHandler(); - } - - private void updateHandshakeHandler() { - if (handshakeHandler instanceof DefaultHandshakeHandler) { - DefaultHandshakeHandler defaultHandshakeHandler = (DefaultHandshakeHandler) handshakeHandler; - if (ObjectUtils.isEmpty(defaultHandshakeHandler.getSupportedProtocols())) { - defaultHandshakeHandler.setSupportedProtocols(this.subProtocols); - } - } } /** @@ -87,7 +71,6 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE public StompWebSocketEndpointRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) { Assert.notNull(handshakeHandler, "'handshakeHandler' must not be null"); this.handshakeHandler = handshakeHandler; - updateHandshakeHandler(); return this; } @@ -97,8 +80,10 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE @Override public SockJsServiceRegistration withSockJS() { this.registration = new StompSockJsServiceRegistration(this.sockJsTaskScheduler); - WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler); - this.registration.setTransportHandlerOverrides(transportHandler); + if (this.handshakeHandler != null) { + WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler); + this.registration.setTransportHandlerOverrides(transportHandler); + } return this.registration; } @@ -114,8 +99,9 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE } else { for (String path : this.paths) { - WebSocketHttpRequestHandler handler = - new WebSocketHttpRequestHandler(this.webSocketHandler, this.handshakeHandler); + WebSocketHttpRequestHandler handler = (this.handshakeHandler != null) ? + new WebSocketHttpRequestHandler(this.webSocketHandler, this.handshakeHandler) : + new WebSocketHttpRequestHandler(this.webSocketHandler); mappings.add(handler, path); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java index a5a8338b1a..c886c50574 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java @@ -32,8 +32,10 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; +import org.springframework.web.socket.support.SubProtocolCapable; import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.support.WebSocketHandlerDecorator; import org.springframework.web.socket.support.WebSocketHttpHeaders; /** @@ -122,10 +124,16 @@ public class DefaultHandshakeHandler implements HandshakeHandler { } /** - * Use this property to configure a list of sub-protocols that are supported. - * The first protocol that matches what the client requested is selected. - * If no protocol matches or this property is not configured, then the - * response will not contain a Sec-WebSocket-Protocol header. + * Use this property to configure the list of supported sub-protocols. + * The first configured sub-protocol that matches a client-requested sub-protocol + * is accepted. If there are no matches the response will not contain a + * {@literal Sec-WebSocket-Protocol} header. + *

+ * Note that if the WebSocketHandler passed in at runtime is an instance of + * {@link SubProtocolCapable} then there is not need to explicitly configure + * this property. That is certainly the case with the built-in STOMP over + * WebSocket support. Therefore this property should be configured explicitly + * only if the WebSocketHandler does not implement {@code SubProtocolCapable}. */ public void setSupportedProtocols(String... protocols) { this.supportedProtocols.clear(); @@ -187,7 +195,10 @@ public class DefaultHandshakeHandler implements HandshakeHandler { "Response update failed during upgrade to WebSocket, uri=" + request.getURI(), ex); } - String subProtocol = selectProtocol(headers.getSecWebSocketProtocol()); + String subProtocol = selectProtocol(headers.getSecWebSocketProtocol(), wsHandler); + if (logger.isDebugEnabled()) { + logger.debug("Selected sub-protocol: '" + subProtocol + "'"); + } List requested = headers.getSecWebSocketExtensions(); List supported = this.requestUpgradeStrategy.getSupportedExtensions(request); @@ -246,17 +257,32 @@ public class DefaultHandshakeHandler implements HandshakeHandler { return true; } - protected String selectProtocol(List requestedProtocols) { + /** + * Perform the sub-protocol negotiation based on requested and supported sub-protocols. + * For the list of supported sub-protocols, this method first checks if the target + * WebSocketHandler is a {@link SubProtocolCapable} and then also checks if any + * sub-protocols have been explicitly configured with + * {@link #setSupportedProtocols(String...)}. + * + * @param requestedProtocols the requested sub-protocols + * @param webSocketHandler the WebSocketHandler that will be used + * @return the selected protocols or {@code null} + * + * @see #determineHandlerSupportedProtocols(org.springframework.web.socket.WebSocketHandler) + */ + protected String selectProtocol(List requestedProtocols, WebSocketHandler webSocketHandler) { if (requestedProtocols != null) { + List handlerProtocols = determineHandlerSupportedProtocols(webSocketHandler); if (logger.isDebugEnabled()) { - logger.debug("Requested sub-protocol(s): " + requestedProtocols - + ", supported sub-protocol(s): " + this.supportedProtocols); + logger.debug("Requested sub-protocol(s): " + requestedProtocols + + ", WebSocketHandler supported sub-protocol(s): " + handlerProtocols + + ", configured sub-protocol(s): " + this.supportedProtocols); } for (String protocol : requestedProtocols) { + if (handlerProtocols.contains(protocol.toLowerCase())) { + return protocol; + } if (this.supportedProtocols.contains(protocol.toLowerCase())) { - if (logger.isDebugEnabled()) { - logger.debug("Selected sub-protocol: '" + protocol + "'"); - } return protocol; } } @@ -264,6 +290,27 @@ public class DefaultHandshakeHandler implements HandshakeHandler { return null; } + /** + * Determine the sub-protocols supported by the given WebSocketHandler by checking + * whether it is an instance of {@link SubProtocolCapable}. + * + * @param handler the handler to check + * @return a list of supported protocols or an empty list + */ + protected final List determineHandlerSupportedProtocols(WebSocketHandler handler) { + List subProtocols = null; + if (handler instanceof SubProtocolCapable) { + subProtocols = ((SubProtocolCapable) handler).getSubProtocols(); + } + else if (handler instanceof WebSocketHandlerDecorator) { + WebSocketHandler lastHandler = ((WebSocketHandlerDecorator) handler).getLastHandler(); + if (lastHandler instanceof SubProtocolCapable) { + subProtocols = ((SubProtocolCapable) lastHandler).getSubProtocols();; + } + } + return (subProtocols != null) ? subProtocols : Collections.emptyList(); + } + /** * Filter the list of requested WebSocket extensions. *

diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/support/SubProtocolCapable.java b/spring-websocket/src/main/java/org/springframework/web/socket/support/SubProtocolCapable.java new file mode 100644 index 0000000000..8fc69230f0 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/support/SubProtocolCapable.java @@ -0,0 +1,19 @@ +package org.springframework.web.socket.support; + +import java.util.List; + +/** + * An interface for WebSocket handlers that support sub-protocols as defined in RFC 6455. + * + * @author Rossen Stoyanchev + * @since 4.0 + * + * @see RFC-6455 section 1.9 + */ +public interface SubProtocolCapable { + + /** + * Return the list of supported sub-protocols. + */ + List getSubProtocols(); +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistrationTests.java index 0e1283e7fb..36b1a06260 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/config/WebMvcStompEndpointRegistrationTests.java @@ -17,7 +17,6 @@ package org.springframework.web.socket.messaging.config; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Map; @@ -62,7 +61,7 @@ public class WebMvcStompEndpointRegistrationTests { WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration( - new String[] {"/foo"}, this.wsHandler, Collections.emptySet(), this.scheduler); + new String[] {"/foo"}, this.wsHandler, this.scheduler); MultiValueMap mappings = registration.getMappings(); assertEquals(1, mappings.size()); @@ -78,7 +77,7 @@ public class WebMvcStompEndpointRegistrationTests { DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler(); WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration( - new String[] {"/foo"}, this.wsHandler, Collections.emptySet(), this.scheduler); + new String[] {"/foo"}, this.wsHandler, this.scheduler); registration.setHandshakeHandler(handshakeHandler); @@ -99,7 +98,7 @@ public class WebMvcStompEndpointRegistrationTests { DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler(); WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration( - new String[] {"/foo"}, this.wsHandler, Collections.emptySet(), this.scheduler); + new String[] {"/foo"}, this.wsHandler, this.scheduler); registration.setHandshakeHandler(handshakeHandler); registration.withSockJS(); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java index 5a96e5ec39..6e1c586dce 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java @@ -16,7 +16,9 @@ package org.springframework.web.socket.server; +import java.util.Arrays; import java.util.Collections; +import java.util.List; import java.util.Map; import org.junit.Before; @@ -24,11 +26,13 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.springframework.web.socket.AbstractHttpRequestTests; +import org.springframework.web.socket.support.SubProtocolCapable; import org.springframework.web.socket.support.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; import org.springframework.web.socket.support.WebSocketHttpHeaders; +import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; @@ -53,15 +57,15 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { @Test - public void selectSubProtocol() throws Exception { + public void supportedSubProtocols() throws Exception { this.handshakeHandler.setSupportedProtocols("stomp", "mqtt"); when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] { "13" }); - WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders()); - this.servletRequest.setMethod("GET"); + + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders()); headers.setUpgrade("WebSocket"); headers.setConnection("Upgrade"); headers.setSecWebSocketVersion("13"); @@ -70,11 +74,70 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { WebSocketHandler handler = new TextWebSocketHandlerAdapter(); Map attributes = Collections.emptyMap(); - this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes); verify(this.upgradeStrategy).upgrade(this.request, this.response, "STOMP", Collections.emptyList(), handler, attributes); } + @Test + public void subProtocolCapableHandler() throws Exception { + + when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[]{"13"}); + + this.servletRequest.setMethod("GET"); + + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders()); + headers.setUpgrade("WebSocket"); + headers.setConnection("Upgrade"); + headers.setSecWebSocketVersion("13"); + headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw=="); + headers.setSecWebSocketProtocol("v11.stomp"); + + WebSocketHandler handler = new SubProtocolCapableHandler("v12.stomp", "v11.stomp"); + Map attributes = Collections.emptyMap(); + this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes); + + verify(this.upgradeStrategy).upgrade(this.request, this.response, + "v11.stomp", Collections.emptyList(), handler, attributes); + } + + @Test + public void subProtocolCapableHandlerNoMatch() throws Exception { + + when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[]{"13"}); + + this.servletRequest.setMethod("GET"); + + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders()); + headers.setUpgrade("WebSocket"); + headers.setConnection("Upgrade"); + headers.setSecWebSocketVersion("13"); + headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw=="); + headers.setSecWebSocketProtocol("v10.stomp"); + + WebSocketHandler handler = new SubProtocolCapableHandler("v12.stomp", "v11.stomp"); + Map attributes = Collections.emptyMap(); + this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes); + + verify(this.upgradeStrategy).upgrade(this.request, this.response, + null, Collections.emptyList(), handler, attributes); + } + + + private static class SubProtocolCapableHandler extends TextWebSocketHandlerAdapter implements SubProtocolCapable { + + private final List subProtocols; + + + private SubProtocolCapableHandler(String... subProtocols) { + this.subProtocols = Arrays.asList(subProtocols); + } + + @Override + public List getSubProtocols() { + return this.subProtocols; + } + } + }