Add HandshakeInterceptor

A HandshakeInterceptor can be used to intercept WebSocket handshakes
(or SockJS requests where a new session is created) in order to
inspect the request and response before and after the handshake
including the ability to pass attributes to the WebSocketHandler,
which the hander can access through
WebSocketSession.getHandshakeAttributes()

An HttpSessionHandshakeInterceptor is available that can copy
attributes from the HTTP session to make them available to the
WebSocket session.

Issue: SPR-10624
This commit is contained in:
Rossen Stoyanchev
2013-08-13 17:23:33 -04:00
parent 9925d8385f
commit 319f18dddf
46 changed files with 903 additions and 170 deletions

View File

@@ -44,7 +44,7 @@ public class JettyWebSocketHandlerAdapterTests {
public void setup() {
this.session = mock(Session.class);
this.webSocketHandler = mock(WebSocketHandler.class);
this.webSocketSession = new JettyWebSocketSession(null);
this.webSocketSession = new JettyWebSocketSession(null, null);
this.adapter = new JettyWebSocketHandlerAdapter(this.webSocketHandler, this.webSocketSession);
}

View File

@@ -50,7 +50,7 @@ public class StandardWebSocketHandlerAdapterTests {
public void setup() {
this.session = mock(Session.class);
this.webSocketHandler = mock(WebSocketHandler.class);
this.webSocketSession = new StandardWebSocketSession(null, null, null);
this.webSocketSession = new StandardWebSocketSession(null, null, null, null);
this.adapter = new StandardWebSocketHandlerAdapter(this.webSocketHandler, this.webSocketSession);
}

View File

@@ -113,7 +113,7 @@ public class JettyWebSocketClientTests {
resp.setAcceptedSubProtocol(req.getSubProtocols().get(0));
}
JettyWebSocketSession session = new JettyWebSocketSession(null);
JettyWebSocketSession session = new JettyWebSocketSession(null, null);
return new JettyWebSocketHandlerAdapter(webSocketHandler, session);
}
});

View File

@@ -16,6 +16,9 @@
package org.springframework.web.socket.server;
import java.util.Collections;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
@@ -62,10 +65,11 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
this.request.getHeaders().setSecWebSocketProtocol("STOMP");
WebSocketHandler handler = new TextWebSocketHandlerAdapter();
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
this.handshakeHandler.doHandshake(this.request, this.response, handler);
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
verify(this.upgradeStrategy).upgrade(request, response, "STOMP", handler);
verify(this.upgradeStrategy).upgrade(this.request, this.response, "STOMP", handler, attributes);
}
}

View File

@@ -0,0 +1,101 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.server.support;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.springframework.web.socket.AbstractHttpRequestTests;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import static org.mockito.Mockito.*;
/**
* Test fixture for {@link HandshakeInterceptorChain}.
*
* @author Rossen Stoyanchev
*/
public class HandshakeInterceptorChainTests extends AbstractHttpRequestTests {
private HandshakeInterceptor i1;
private HandshakeInterceptor i2;
private HandshakeInterceptor i3;
private List<HandshakeInterceptor> interceptors;
private WebSocketHandler wsHandler;
private Map<String, Object> attributes;
@Before
public void setup() {
i1 = mock(HandshakeInterceptor.class);
i2 = mock(HandshakeInterceptor.class);
i3 = mock(HandshakeInterceptor.class);
interceptors = Arrays.asList(i1, i2, i3);
wsHandler = mock(WebSocketHandler.class);
attributes = new HashMap<String, Object>();
}
@Test
public void success() throws Exception {
when(i1.beforeHandshake(request, response, wsHandler, attributes)).thenReturn(true);
when(i2.beforeHandshake(request, response, wsHandler, attributes)).thenReturn(true);
when(i3.beforeHandshake(request, response, wsHandler, attributes)).thenReturn(true);
HandshakeInterceptorChain chain = new HandshakeInterceptorChain(interceptors, wsHandler);
chain.applyBeforeHandshake(request, response, attributes);
verify(i1).beforeHandshake(request, response, wsHandler, attributes);
verify(i2).beforeHandshake(request, response, wsHandler, attributes);
verify(i3).beforeHandshake(request, response, wsHandler, attributes);
verifyNoMoreInteractions(i1, i2, i3);
}
@Test
public void applyBeforeHandshakeWithFalseReturnValue() throws Exception {
when(i1.beforeHandshake(request, response, wsHandler, attributes)).thenReturn(true);
when(i2.beforeHandshake(request, response, wsHandler, attributes)).thenReturn(false);
HandshakeInterceptorChain chain = new HandshakeInterceptorChain(interceptors, wsHandler);
chain.applyBeforeHandshake(request, response, attributes);
verify(i1).beforeHandshake(request, response, wsHandler, attributes);
verify(i1).afterHandshake(request, response, wsHandler, null);
verify(i2).beforeHandshake(request, response, wsHandler, attributes);
verifyNoMoreInteractions(i1, i2, i3);
}
@Test
public void applyAfterHandshakeOnly() {
HandshakeInterceptorChain chain = new HandshakeInterceptorChain(interceptors, wsHandler);
chain.applyAfterHandshake(request, response, null);
verifyNoMoreInteractions(i1, i2, i3);
}
}

View File

@@ -0,0 +1,86 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.server.support;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.web.socket.AbstractHttpRequestTests;
import org.springframework.web.socket.WebSocketHandler;
import static org.junit.Assert.*;
/**
* Test fixture for {@link HttpSessionHandshakeInterceptor}.
*
* @author Rossen Stoyanchev
*/
public class HttpSessionHandshakeInterceptorTests extends AbstractHttpRequestTests {
@Test
public void copyAllAttributes() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.getSession().setAttribute("foo", "bar");
this.servletRequest.getSession().setAttribute("bar", "baz");
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
interceptor.beforeHandshake(request, response, wsHandler, attributes);
assertEquals(2, attributes.size());
assertEquals("bar", attributes.get("foo"));
assertEquals("baz", attributes.get("bar"));
}
@Test
public void copySelectedAttributes() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.getSession().setAttribute("foo", "bar");
this.servletRequest.getSession().setAttribute("bar", "baz");
Set<String> names = Collections.singleton("foo");
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(names);
interceptor.beforeHandshake(request, response, wsHandler, attributes);
assertEquals(1, attributes.size());
assertEquals("bar", attributes.get("foo"));
}
@Test
public void doNotCauseSessionCreation() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
interceptor.beforeHandshake(request, response, wsHandler, attributes);
assertNull(this.servletRequest.getSession(false));
}
}

View File

@@ -17,6 +17,7 @@
package org.springframework.web.socket.sockjs.transport.handler;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -70,10 +71,11 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
MockitoAnnotations.initMocks(this);
this.session = new TestSockJsSession(sessionId, new StubSockJsServiceConfig(), this.wsHandler);
Map<String, Object> attributes = Collections.emptyMap();
this.session = new TestSockJsSession(sessionId, new StubSockJsServiceConfig(), this.wsHandler, attributes);
when(this.xhrHandler.getTransportType()).thenReturn(TransportType.XHR);
when(this.xhrHandler.createSession(sessionId, this.wsHandler)).thenReturn(this.session);
when(this.xhrHandler.createSession(sessionId, this.wsHandler, attributes)).thenReturn(this.session);
when(this.xhrSendHandler.getTransportType()).thenReturn(TransportType.XHR_SEND);
this.service = new DefaultSockJsService(this.taskScheduler,

View File

@@ -107,7 +107,7 @@ public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTest
this.servletRequest.setContent("[\"x\"]".getBytes("UTF-8"));
WebSocketHandler wsHandler = mock(WebSocketHandler.class);
TestHttpSockJsSession session = new TestHttpSockJsSession("1", sockJsConfig, wsHandler);
TestHttpSockJsSession session = new TestHttpSockJsSession("1", sockJsConfig, wsHandler, null);
session.delegateConnectionEstablished();
doThrow(new Exception()).when(wsHandler).handleMessage(session, new TextMessage("x"));
@@ -127,7 +127,7 @@ public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTest
private void handleRequest(AbstractHttpReceivingTransportHandler transportHandler) throws Exception {
WebSocketHandler wsHandler = mock(WebSocketHandler.class);
AbstractSockJsSession session = new TestHttpSockJsSession("1", new StubSockJsServiceConfig(), wsHandler);
AbstractSockJsSession session = new TestHttpSockJsSession("1", new StubSockJsServiceConfig(), wsHandler, null);
transportHandler.setSockJsServiceConfiguration(new StubSockJsServiceConfig());
transportHandler.handleRequest(this.request, this.response, wsHandler, session);
@@ -141,7 +141,7 @@ public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTest
resetResponse();
WebSocketHandler wsHandler = mock(WebSocketHandler.class);
AbstractSockJsSession session = new TestHttpSockJsSession("1", new StubSockJsServiceConfig(), wsHandler);
AbstractSockJsSession session = new TestHttpSockJsSession("1", new StubSockJsServiceConfig(), wsHandler, null);
new XhrReceivingTransportHandler().handleRequest(this.request, this.response, wsHandler, session);

View File

@@ -66,7 +66,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests
XhrPollingTransportHandler transportHandler = new XhrPollingTransportHandler();
transportHandler.setSockJsServiceConfiguration(this.sockJsConfig);
AbstractSockJsSession session = transportHandler.createSession("1", this.webSocketHandler);
AbstractSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null);
transportHandler.handleRequest(this.request, this.response, this.webSocketHandler, session);
assertEquals("application/javascript;charset=UTF-8", this.response.getHeaders().getContentType().toString());
@@ -92,7 +92,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests
JsonpPollingTransportHandler transportHandler = new JsonpPollingTransportHandler();
transportHandler.setSockJsServiceConfiguration(this.sockJsConfig);
PollingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler);
PollingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null);
transportHandler.handleRequest(this.request, this.response, this.webSocketHandler, session);
@@ -114,7 +114,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests
XhrStreamingTransportHandler transportHandler = new XhrStreamingTransportHandler();
transportHandler.setSockJsServiceConfiguration(this.sockJsConfig);
AbstractSockJsSession session = transportHandler.createSession("1", this.webSocketHandler);
AbstractSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null);
transportHandler.handleRequest(this.request, this.response, this.webSocketHandler, session);
@@ -128,7 +128,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests
HtmlFileTransportHandler transportHandler = new HtmlFileTransportHandler();
transportHandler.setSockJsServiceConfiguration(this.sockJsConfig);
StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler);
StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null);
transportHandler.handleRequest(this.request, this.response, this.webSocketHandler, session);
@@ -150,7 +150,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests
EventSourceTransportHandler transportHandler = new EventSourceTransportHandler();
transportHandler.setSockJsServiceConfiguration(this.sockJsConfig);
StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler);
StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null);
transportHandler.handleRequest(this.request, this.response, this.webSocketHandler, session);

View File

@@ -17,6 +17,7 @@
package org.springframework.web.socket.sockjs.transport.session;
import java.io.IOException;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
@@ -71,7 +72,7 @@ public class AbstractHttpSockJsSessionTests extends BaseAbstractSockJsSessionTes
@Override
protected TestAbstractHttpSockJsSession initSockJsSession() {
return new TestAbstractHttpSockJsSession(this.sockJsConfig, this.webSocketHandler);
return new TestAbstractHttpSockJsSession(this.sockJsConfig, this.webSocketHandler, null);
}
@Test
@@ -126,8 +127,10 @@ public class AbstractHttpSockJsSessionTests extends BaseAbstractSockJsSessionTes
private boolean heartbeatScheduled;
public TestAbstractHttpSockJsSession(SockJsServiceConfig config, WebSocketHandler handler) {
super("1", config, handler);
public TestAbstractHttpSockJsSession(SockJsServiceConfig config, WebSocketHandler handler,
Map<String, Object> attributes) {
super("1", config, handler, attributes);
}
public boolean wasCacheFlushed() {

View File

@@ -45,7 +45,8 @@ public class AbstractSockJsSessionTests extends BaseAbstractSockJsSessionTests<T
@Override
protected TestSockJsSession initSockJsSession() {
return new TestSockJsSession("1", this.sockJsConfig, this.webSocketHandler);
return new TestSockJsSession("1", this.sockJsConfig, this.webSocketHandler,
Collections.<String, Object>emptyMap());
}
@Test
@@ -102,7 +103,8 @@ public class AbstractSockJsSessionTests extends BaseAbstractSockJsSessionTests<T
public void delegateMessagesWithErrorAndConnectionClosing() throws Exception {
WebSocketHandler wsHandler = new ExceptionWebSocketHandlerDecorator(this.webSocketHandler);
TestSockJsSession sockJsSession = new TestSockJsSession("1", this.sockJsConfig, wsHandler);
TestSockJsSession sockJsSession = new TestSockJsSession("1", this.sockJsConfig,
wsHandler, Collections.<String, Object>emptyMap());
String msg1 = "message 1";
String msg2 = "message 2";

View File

@@ -19,6 +19,7 @@ package org.springframework.web.socket.sockjs.transport.session;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHandler;
@@ -45,8 +46,10 @@ public class TestHttpSockJsSession extends AbstractHttpSockJsSession {
private String subProtocol;
public TestHttpSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) {
super(sessionId, config, handler);
public TestHttpSockJsSession(String sessionId, SockJsServiceConfig config,
WebSocketHandler wsHandler, Map<String, Object> attributes) {
super(sessionId, config, wsHandler, attributes);
}
@Override

View File

@@ -18,9 +18,11 @@ package org.springframework.web.socket.sockjs.transport.session;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.CloseStatus;
@@ -32,6 +34,8 @@ import org.springframework.web.socket.sockjs.support.frame.SockJsFrame;
*/
public class TestSockJsSession extends AbstractSockJsSession {
private URI uri;
private HttpHeaders headers;
private Principal principal;
@@ -55,11 +59,22 @@ public class TestSockJsSession extends AbstractSockJsSession {
private String subProtocol;
public TestSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) {
super(sessionId, config, handler);
public TestSockJsSession(String sessionId, SockJsServiceConfig config,
WebSocketHandler wsHandler, Map<String, Object> attributes) {
super(sessionId, config, wsHandler, attributes);
}
public void setUri(URI uri) {
this.uri = uri;
}
@Override
public URI getUri() {
return this.uri;
}
@Override
public HttpHeaders getHandshakeHeaders() {
return this.headers;

View File

@@ -21,6 +21,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
@@ -53,7 +54,8 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession
@Override
protected TestWebSocketServerSockJsSession initSockJsSession() {
return new TestWebSocketServerSockJsSession(this.sockJsConfig, this.webSocketHandler);
return new TestWebSocketServerSockJsSession(this.sockJsConfig, this.webSocketHandler,
Collections.<String, Object>emptyMap());
}
@Test
@@ -132,8 +134,10 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession
private final List<String> heartbeatSchedulingEvents = new ArrayList<>();
public TestWebSocketServerSockJsSession(SockJsServiceConfig config, WebSocketHandler handler) {
super("1", config, handler);
public TestWebSocketServerSockJsSession(SockJsServiceConfig config, WebSocketHandler handler,
Map<String, Object> attributes) {
super("1", config, handler, attributes);
}
@Override

View File

@@ -21,7 +21,9 @@ import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.CloseStatus;
@@ -39,6 +41,8 @@ public class TestWebSocketSession implements WebSocketSession {
private URI uri;
private Map<String, Object> attributes = new HashMap<String, Object>();
private Principal principal;
private InetSocketAddress localAddress;
@@ -106,6 +110,21 @@ public class TestWebSocketSession implements WebSocketSession {
this.headers = headers;
}
/**
* @param attributes the attributes to set
*/
public void setHandshakeAttributes(Map<String, Object> attributes) {
this.attributes = attributes;
}
/**
* @return the attributes
*/
@Override
public Map<String, Object> getHandshakeAttributes() {
return this.attributes;
}
/**
* @return the principal
*/