diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java index 14b5402352..4229e6efb3 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.config.annotation; import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; /** * A contract for configuring a STOMP over WebSocket endpoint. @@ -36,4 +37,9 @@ public interface StompWebSocketEndpointRegistration { */ StompWebSocketEndpointRegistration setHandshakeHandler(HandshakeHandler handshakeHandler); + /** + * Configure the HandshakeInterceptor's to use. + */ + StompWebSocketEndpointRegistration addInterceptors(HandshakeInterceptor... interceptors); + } \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java index 8b0bc31a42..a9e24920a4 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java @@ -23,11 +23,14 @@ import org.springframework.util.MultiValueMap; import org.springframework.web.HttpRequestHandler; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; +import java.util.Arrays; + /** * An abstract base class class for configuring STOMP over WebSocket/SockJS endpoints. * @@ -44,6 +47,8 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE private HandshakeHandler handshakeHandler; + private HandshakeInterceptor[] interceptors; + private StompSockJsServiceRegistration registration; @@ -58,9 +63,6 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE this.sockJsTaskScheduler = sockJsTaskScheduler; } - /** - * Provide a custom or pre-configured {@link HandshakeHandler}. - */ @Override public StompWebSocketEndpointRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) { Assert.notNull(handshakeHandler, "'handshakeHandler' must not be null"); @@ -68,12 +70,22 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE return this; } - /** - * Enable SockJS fallback options. - */ + @Override + public StompWebSocketEndpointRegistration addInterceptors(HandshakeInterceptor... interceptors) { + this.interceptors = interceptors; + return this; + } + + protected HandshakeInterceptor[] getInterceptors() { + return this.interceptors; + } + @Override public SockJsServiceRegistration withSockJS() { this.registration = new StompSockJsServiceRegistration(this.sockJsTaskScheduler); + if (this.interceptors != null) { + this.registration.setInterceptors(this.interceptors); + } if (this.handshakeHandler != null) { WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler); this.registration.setTransportHandlerOverrides(transportHandler); @@ -93,9 +105,16 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE } else { for (String path : this.paths) { - WebSocketHttpRequestHandler handler = (this.handshakeHandler != null) ? - new WebSocketHttpRequestHandler(this.webSocketHandler, this.handshakeHandler) : - new WebSocketHttpRequestHandler(this.webSocketHandler); + WebSocketHttpRequestHandler handler; + if (this.handshakeHandler != null) { + handler = new WebSocketHttpRequestHandler(this.webSocketHandler, this.handshakeHandler); + } + else { + handler = new WebSocketHttpRequestHandler(this.webSocketHandler); + } + if (this.interceptors != null) { + handler.setHandshakeInterceptors(Arrays.asList(this.interceptors)); + } mappings.add(handler, path); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java index 56a0b9823b..7cf381b6bb 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java @@ -29,7 +29,9 @@ import org.springframework.scheduling.TaskScheduler; import org.springframework.util.MultiValueMap; import org.springframework.web.HttpRequestHandler; import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; +import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; import org.springframework.web.socket.sockjs.transport.TransportHandler; @@ -38,6 +40,8 @@ import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsServ import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; import static org.junit.Assert.*; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.mock; /** @@ -73,12 +77,15 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { } @Test - public void customHandshakeHandler() { + public void handshakeHandlerAndInterceptors() { WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler(); + HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); + registration.setHandshakeHandler(handshakeHandler); + registration.addInterceptors(interceptor); MultiValueMap mappings = registration.getMappings(); assertEquals(1, mappings.size()); @@ -89,15 +96,19 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey(); assertNotNull(requestHandler.getWebSocketHandler()); assertSame(handshakeHandler, requestHandler.getHandshakeHandler()); + assertEquals(Arrays.asList(interceptor), requestHandler.getHandshakeInterceptors()); } @Test - public void customHandshakeHandlerPassedToSockJsService() { + public void handshakeHandlerAndInterceptorsWithSockJsService() { WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler(); + HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); + registration.setHandshakeHandler(handshakeHandler); + registration.addInterceptors(interceptor); registration.withSockJS(); MultiValueMap mappings = registration.getMappings(); @@ -115,6 +126,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { Map handlers = sockJsService.getTransportHandlers(); WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET); assertSame(handshakeHandler, transportHandler.getHandshakeHandler()); + assertEquals(Arrays.asList(interceptor), sockJsService.getHandshakeInterceptors()); } }