diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandler.java index 48fa86e485..cd9bb4727f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,6 +43,14 @@ public interface TransportHandler { */ TransportType getTransportType(); + /** + * Whether the type of the given session matches the transport type of this + * {@code TransportHandler} where session id and the transport type are + * extracted from the SockJS URL. + * @since 4.3.3 + */ + boolean checkSessionType(SockJsSession session); + /** * Handle the given request and delegate messages to the provided * {@link WebSocketHandler}. diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java index 4e527e4258..a16b46da8d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java @@ -291,6 +291,11 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem return; } } + if (!transportHandler.checkSessionType(session)) { + logger.debug("Session type does not match the transport type for the request."); + response.setStatusCode(HttpStatus.NOT_FOUND); + return; + } } if (transportType.sendsNoCacheInstruction()) { @@ -303,7 +308,10 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem } } + transportHandler.handleRequest(request, response, handler, session); + + chain.applyAfterHandshake(request, response, null); } catch (SockJsException ex) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpReceivingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpReceivingTransportHandler.java index 10563d9ec9..ef60a60d41 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpReceivingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpReceivingTransportHandler.java @@ -99,4 +99,9 @@ public abstract class AbstractHttpReceivingTransportHandler extends AbstractTran protected abstract HttpStatus getResponseStatus(); + @Override + public boolean checkSessionType(SockJsSession session) { + return session instanceof AbstractHttpSockJsSession; + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java index 0034805674..9988d51d45 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java @@ -25,7 +25,9 @@ import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.frame.DefaultSockJsFrameFormat; import org.springframework.web.socket.sockjs.frame.SockJsFrameFormat; import org.springframework.web.socket.sockjs.transport.SockJsServiceConfig; +import org.springframework.web.socket.sockjs.transport.SockJsSession; import org.springframework.web.socket.sockjs.transport.TransportType; +import org.springframework.web.socket.sockjs.transport.session.PollingSockJsSession; import org.springframework.web.socket.sockjs.transport.session.StreamingSockJsSession; /** @@ -47,6 +49,11 @@ public class EventSourceTransportHandler extends AbstractHttpSendingTransportHan return new MediaType("text", "event-stream", StandardCharsets.UTF_8); } + @Override + public boolean checkSessionType(SockJsSession session) { + return session instanceof EventSourceStreamingSockJsSession; + } + @Override public StreamingSockJsSession createSession( String sessionId, WebSocketHandler handler, Map attributes) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java index 22e6925959..7a96a66cff 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java @@ -32,9 +32,11 @@ import org.springframework.web.socket.sockjs.SockJsTransportFailureException; import org.springframework.web.socket.sockjs.frame.DefaultSockJsFrameFormat; import org.springframework.web.socket.sockjs.frame.SockJsFrameFormat; import org.springframework.web.socket.sockjs.transport.SockJsServiceConfig; +import org.springframework.web.socket.sockjs.transport.SockJsSession; import org.springframework.web.socket.sockjs.transport.TransportHandler; import org.springframework.web.socket.sockjs.transport.TransportType; import org.springframework.web.socket.sockjs.transport.session.AbstractHttpSockJsSession; +import org.springframework.web.socket.sockjs.transport.session.PollingSockJsSession; import org.springframework.web.socket.sockjs.transport.session.StreamingSockJsSession; import org.springframework.web.util.JavaScriptUtils; @@ -88,6 +90,11 @@ public class HtmlFileTransportHandler extends AbstractHttpSendingTransportHandle return new MediaType("text", "html", StandardCharsets.UTF_8); } + @Override + public boolean checkSessionType(SockJsSession session) { + return session instanceof HtmlFileStreamingSockJsSession; + } + @Override public StreamingSockJsSession createSession( String sessionId, WebSocketHandler handler, Map attributes) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/JsonpPollingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/JsonpPollingTransportHandler.java index 926c0413a5..e4f6b3d7dd 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/JsonpPollingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/JsonpPollingTransportHandler.java @@ -30,6 +30,7 @@ import org.springframework.web.socket.sockjs.SockJsException; import org.springframework.web.socket.sockjs.SockJsTransportFailureException; import org.springframework.web.socket.sockjs.frame.DefaultSockJsFrameFormat; import org.springframework.web.socket.sockjs.frame.SockJsFrameFormat; +import org.springframework.web.socket.sockjs.transport.SockJsSession; import org.springframework.web.socket.sockjs.transport.TransportType; import org.springframework.web.socket.sockjs.transport.session.AbstractHttpSockJsSession; import org.springframework.web.socket.sockjs.transport.session.PollingSockJsSession; @@ -53,6 +54,11 @@ public class JsonpPollingTransportHandler extends AbstractHttpSendingTransportHa return new MediaType("application", "javascript", StandardCharsets.UTF_8); } + @Override + public boolean checkSessionType(SockJsSession session) { + return session instanceof PollingSockJsSession; + } + @Override public PollingSockJsSession createSession( String sessionId, WebSocketHandler handler, Map attributes) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java index d7316f4786..6915512068 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/WebSocketTransportHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -103,6 +103,11 @@ public class WebSocketTransportHandler extends AbstractTransportHandler } } + @Override + public boolean checkSessionType(SockJsSession session) { + return session instanceof WebSocketServerSockJsSession; + } + @Override public AbstractSockJsSession createSession(String id, WebSocketHandler handler, Map attrs) { return new WebSocketServerSockJsSession(id, getServiceConfig(), handler, attrs); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrPollingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrPollingTransportHandler.java index 95740aa599..e635f17a25 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrPollingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrPollingTransportHandler.java @@ -24,8 +24,10 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.frame.DefaultSockJsFrameFormat; import org.springframework.web.socket.sockjs.frame.SockJsFrameFormat; +import org.springframework.web.socket.sockjs.transport.SockJsSession; import org.springframework.web.socket.sockjs.transport.TransportHandler; import org.springframework.web.socket.sockjs.transport.TransportType; +import org.springframework.web.socket.sockjs.transport.session.AbstractHttpSockJsSession; import org.springframework.web.socket.sockjs.transport.session.PollingSockJsSession; /** @@ -51,6 +53,11 @@ public class XhrPollingTransportHandler extends AbstractHttpSendingTransportHand return new DefaultSockJsFrameFormat("%s\n"); } + @Override + public boolean checkSessionType(SockJsSession session) { + return session instanceof PollingSockJsSession; + } + @Override public PollingSockJsSession createSession( String sessionId, WebSocketHandler handler, Map attributes) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java index 6453814d6a..a29557c070 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java @@ -25,8 +25,10 @@ import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.frame.DefaultSockJsFrameFormat; import org.springframework.web.socket.sockjs.frame.SockJsFrameFormat; import org.springframework.web.socket.sockjs.transport.SockJsServiceConfig; +import org.springframework.web.socket.sockjs.transport.SockJsSession; import org.springframework.web.socket.sockjs.transport.TransportHandler; import org.springframework.web.socket.sockjs.transport.TransportType; +import org.springframework.web.socket.sockjs.transport.session.PollingSockJsSession; import org.springframework.web.socket.sockjs.transport.session.StreamingSockJsSession; /** @@ -57,6 +59,11 @@ public class XhrStreamingTransportHandler extends AbstractHttpSendingTransportHa return new MediaType("application", "javascript", StandardCharsets.UTF_8); } + @Override + public boolean checkSessionType(SockJsSession session) { + return session instanceof XhrStreamingSockJsSession; + } + @Override public StreamingSockJsSession createSession( String sessionId, WebSocketHandler handler, Map attributes) { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java index aea9b0e802..a45a3e701b 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,7 +44,7 @@ import static org.junit.Assert.*; import static org.mockito.BDDMockito.*; /** - * Test fixture for {@link org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService}. + * Test fixture for {@link DefaultSockJsService}. * * @author Rossen Stoyanchev * @author Sebastien Deleuze @@ -239,6 +239,7 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { resetResponse(); sockJsPath = sessionUrlPrefix + "xhr_send"; setRequest("POST", sockJsPrefix + sockJsPath); + given(this.xhrSendHandler.checkSessionType(this.session)).willReturn(true); this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); assertEquals(200, this.servletResponse.getStatus()); // session exists