Add an option to set an Origin whitelist for Websocket and SockJS

This commit introduces a new OriginHandshakeInterceptor. It filters
Origin header value against a list of allowed origins.

AbstractSockJsService as been modified to:
 - Reject CORS requests with forbidden origins
 - Disable transport types that does not support CORS when an origin
   check is required
 - Use the Origin request header value instead of "*" for
   Access-Control-Allow-Origin response header value
   (mandatory when  Access-Control-Allow-Credentials=true)
 - Return CORS header only if the request contains an Origin header

It is possible to configure easily this behavior thanks to JavaConfig API
WebSocketHandlerRegistration#addAllowedOrigins(String...) and
StompWebSocketEndpointRegistration#addAllowedOrigins(String...).
It is also possible to configure it using the websocket XML namespace.

Please notice that this commit does not change the default behavior:
cross origin requests are still enabled by default.

Issues: SPR-12226
This commit is contained in:
Sebastien Deleuze
2014-10-25 23:30:40 +02:00
parent 28a3cd50aa
commit 743356fa21
26 changed files with 870 additions and 70 deletions

View File

@@ -54,6 +54,10 @@ public abstract class AbstractHttpRequestTests {
this.servletRequest.setRequestURI(requestUri);
}
protected void setOrigin(String origin) {
this.servletRequest.addHeader("Origin", origin);
}
protected void resetRequestAndResponse() {
resetRequest();
resetResponse();

View File

@@ -18,11 +18,13 @@ package org.springframework.web.socket.config;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ScheduledFuture;
import static org.junit.Assert.assertEquals;
import org.junit.Before;
import org.junit.Test;
@@ -45,6 +47,7 @@ import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec;
@@ -103,6 +106,7 @@ public class HandlersBeanDefinitionParserTests {
HandshakeHandler handshakeHandler = handler.getHandshakeHandler();
assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof DefaultHandshakeHandler);
assertTrue(handler.getHandshakeInterceptors().isEmpty());
}
else {
assertThat(shm.getUrlMap().keySet(), contains("/test"));
@@ -112,6 +116,7 @@ public class HandlersBeanDefinitionParserTests {
HandshakeHandler handshakeHandler = handler.getHandshakeHandler();
assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof DefaultHandshakeHandler);
assertTrue(handler.getHandshakeInterceptors().isEmpty());
}
}
}
@@ -135,7 +140,8 @@ public class HandlersBeanDefinitionParserTests {
assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof TestHandshakeHandler);
List<HandshakeInterceptor> interceptors = handler.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class)));
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class),
instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class)));
handler = (WebSocketHttpRequestHandler) urlHandlerMapping.getUrlMap().get("/test");
assertNotNull(handler);
@@ -144,8 +150,8 @@ public class HandlersBeanDefinitionParserTests {
assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof TestHandshakeHandler);
interceptors = handler.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class)));
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class),
instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class)));
}
@Test
@@ -222,6 +228,10 @@ public class HandlersBeanDefinitionParserTests {
assertEquals(1024, transportService.getHttpMessageCacheSize());
assertEquals(20, transportService.getHeartbeatTime());
assertEquals(TestMessageCodec.class, transportService.getMessageCodec().getClass());
List<HandshakeInterceptor> interceptors = transportService.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(OriginHandshakeInterceptor.class)));
assertEquals(Arrays.asList("http://mydomain1.com", "http://mydomain2.com"), transportService.getAllowedOrigins());
}
private void loadBeanDefinitions(String fileName) {

View File

@@ -68,6 +68,7 @@ import org.springframework.web.socket.messaging.StompSubProtocolHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.transport.TransportType;
@@ -115,7 +116,8 @@ public class MessageBrokerBeanDefinitionParserTests {
assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof TestHandshakeHandler);
List<HandshakeInterceptor> interceptors = wsHttpRequestHandler.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class)));
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class),
instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class)));
WebSocketSession session = new TestWebSocketSession("id");
wsHttpRequestHandler.getWebSocketHandler().afterConnectionEstablished(session);
@@ -158,7 +160,9 @@ public class MessageBrokerBeanDefinitionParserTests {
assertTrue(scheduler.getScheduledThreadPoolExecutor().getRemoveOnCancelPolicy());
interceptors = defaultSockJsService.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class)));
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class),
instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class)));
assertEquals(Arrays.asList("http://mydomain3.com", "http://mydomain4.com"), defaultSockJsService.getAllowedOrigins());
UserSessionRegistry userSessionRegistry = this.appContext.getBean(UserSessionRegistry.class);
assertNotNull(userSessionRegistry);

View File

@@ -29,6 +29,7 @@ import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
@@ -70,19 +71,60 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
assertNotNull(((WebSocketHttpRequestHandler) entry.getKey()).getWebSocketHandler());
assertTrue(((WebSocketHttpRequestHandler) entry.getKey()).getHandshakeInterceptors().isEmpty());
assertEquals(Arrays.asList("/foo"), entry.getValue());
}
@Test
public void handshakeHandlerAndInterceptors() {
public void allowedOrigins() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.setAllowedOrigins("http://mydomain.com");
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertNotNull(requestHandler.getWebSocketHandler());
assertEquals(1, requestHandler.getHandshakeInterceptors().size());
assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass());
}
@Test
public void allowedOriginsWithSockJsService() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
String origin = "http://mydomain.com";
registration.setAllowedOrigins(origin).withSockJS();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertNotNull(requestHandler.getSockJsService());
DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins());
registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.withSockJS().setAllowedOrigins(origin);
mappings = registration.getMappings();
assertEquals(1, mappings.size());
requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertNotNull(requestHandler.getSockJsService());
sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins());
}
@Test
public void handshakeHandlerAndInterceptor() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
registration.setHandshakeHandler(handshakeHandler);
registration.addInterceptors(interceptor);
registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor);
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
@@ -97,16 +139,38 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
}
@Test
public void handshakeHandlerAndInterceptorsWithSockJsService() {
public void handshakeHandlerAndInterceptorWithAllowedOrigins() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
String origin = "http://mydomain.com";
registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).setAllowedOrigins(origin);
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
assertEquals(Arrays.asList("/foo"), entry.getValue());
WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey();
assertNotNull(requestHandler.getWebSocketHandler());
assertSame(handshakeHandler, requestHandler.getHandshakeHandler());
assertEquals(2, requestHandler.getHandshakeInterceptors().size());
assertEquals(interceptor, requestHandler.getHandshakeInterceptors().get(0));
assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(1).getClass());
}
@Test
public void handshakeHandlerInterceptorWithSockJsService() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
registration.setHandshakeHandler(handshakeHandler);
registration.addInterceptors(interceptor);
registration.withSockJS();
registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).withSockJS();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
@@ -126,4 +190,37 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
assertEquals(Arrays.asList(interceptor), sockJsService.getHandshakeInterceptors());
}
@Test
public void handshakeHandlerInterceptorWithSockJsServiceAndAllowedOrigins() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
String origin = "http://mydomain.com";
registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).setAllowedOrigins(origin).withSockJS();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
assertEquals(Arrays.asList("/foo/**"), entry.getValue());
SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler) entry.getKey();
assertNotNull(requestHandler.getWebSocketHandler());
DefaultSockJsService sockJsService = (DefaultSockJsService) requestHandler.getSockJsService();
assertNotNull(sockJsService);
Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers();
WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET);
assertSame(handshakeHandler, transportHandler.getHandshakeHandler());
assertEquals(2, sockJsService.getHandshakeInterceptors().size());
assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0));
assertEquals(OriginHandshakeInterceptor.class,
sockJsService.getHandshakeInterceptors().get(1).getClass());
assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins());
}
}

View File

@@ -29,6 +29,7 @@ import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
import org.springframework.web.socket.sockjs.SockJsService;
@@ -68,10 +69,12 @@ public class WebSocketHandlerRegistrationTests {
Mapping m1 = mappings.get(0);
assertEquals(handler, m1.webSocketHandler);
assertEquals("/foo", m1.path);
assertEquals(0, m1.interceptors.length);
Mapping m2 = mappings.get(1);
assertEquals(handler, m2.webSocketHandler);
assertEquals("/bar", m2.path);
assertEquals(0, m2.interceptors.length);
}
@Test
@@ -90,12 +93,31 @@ public class WebSocketHandlerRegistrationTests {
assertArrayEquals(new HandshakeInterceptor[] {interceptor}, mapping.interceptors);
}
@Test
public void interceptorsWithAllowedOrigins() {
WebSocketHandler handler = new TextWebSocketHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).setAllowedOrigins("http://mydomain1.com");
List<Mapping> mappings = this.registration.getMappings();
assertEquals(1, mappings.size());
Mapping mapping = mappings.get(0);
assertEquals(handler, mapping.webSocketHandler);
assertEquals("/foo", mapping.path);
assertEquals(2, mapping.interceptors.length);
assertEquals(interceptor, mapping.interceptors[0]);
assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
}
@Test
public void interceptorsPassedToSockJsRegistration() {
WebSocketHandler handler = new TextWebSocketHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).withSockJS();
this.registration.addHandler(handler, "/foo").addInterceptors(interceptor)
.setAllowedOrigins("http://mydomain1.com").withSockJS();
List<Mapping> mappings = this.registration.getMappings();
assertEquals(1, mappings.size());
@@ -104,7 +126,11 @@ public class WebSocketHandlerRegistrationTests {
assertEquals(handler, mapping.webSocketHandler);
assertEquals("/foo/**", mapping.path);
assertNotNull(mapping.sockJsService);
assertEquals(Arrays.asList(interceptor), mapping.sockJsService.getHandshakeInterceptors());
assertEquals(Arrays.asList("http://mydomain1.com"),
mapping.sockJsService.getAllowedOrigins());
List<HandshakeInterceptor> interceptors = mapping.sockJsService.getHandshakeInterceptors();
assertEquals(interceptor, interceptors.get(0));
assertEquals(OriginHandshakeInterceptor.class, interceptors.get(1).getClass());
}
@Test

View File

@@ -0,0 +1,106 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.server.support;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import static org.junit.Assert.*;
import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.http.HttpStatus;
import org.springframework.web.socket.AbstractHttpRequestTests;
import org.springframework.web.socket.WebSocketHandler;
/**
* Test fixture for {@link OriginHandshakeInterceptor}.
*
* @author Sebastien Deleuze
*/
public class AllowedOriginsInterceptorTests extends AbstractHttpRequestTests {
@Test
public void originValueMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@Test
public void originValueNoMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain2.com"));
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@Test
public void originListMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@Test
public void originListNoMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain4.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@Test
public void noOriginNoMatchWithNullHostileCollection() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
Set<String> allowedOrigins = new ConcurrentSkipListSet<String>();
allowedOrigins.add("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(allowedOrigins);
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@Test
public void noOriginNoMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
}

View File

@@ -17,9 +17,12 @@
package org.springframework.web.socket.sockjs.support;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletResponse;
import static org.junit.Assert.assertEquals;
import org.junit.Before;
import org.junit.Test;
@@ -40,9 +43,12 @@ import static org.mockito.BDDMockito.*;
* Test fixture for {@link AbstractSockJsService}.
*
* @author Rossen Stoyanchev
* @author Sebastien Deleuze
*/
public class SockJsServiceTests extends AbstractHttpRequestTests {
private static final List<String> origins = Arrays.asList("http://mydomain1.com", "http://mydomain2.com");
private TestSockJsService service;
private WebSocketHandler handler;
@@ -80,10 +86,10 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("application/json;charset=UTF-8", this.servletResponse.getContentType());
assertEquals("*", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("no-store, no-cache, must-revalidate, max-age=0", this.servletResponse.getHeader("Cache-Control"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Vary"));
String body = this.servletResponse.getContentAsString();
assertEquals("{\"entropy\"", body.substring(0, body.indexOf(':')));
@@ -97,6 +103,47 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
body = this.servletResponse.getContentAsString();
assertEquals(",\"origins\":[\"*:*\"],\"cookie_needed\":false,\"websocket\":false}",
body.substring(body.indexOf(',')));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.FORBIDDEN);
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Vary"));
}
@Test // SPR-12226
public void handleInfoGetWithOrigin() throws Exception {
setOrigin("http://mydomain2.com");
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("application/json;charset=UTF-8", this.servletResponse.getContentType());
assertEquals("no-store, no-cache, must-revalidate, max-age=0", this.servletResponse.getHeader("Cache-Control"));
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
String body = this.servletResponse.getContentAsString();
assertEquals("{\"entropy\"", body.substring(0, body.indexOf(':')));
assertEquals(",\"origins\":[\"*:*\"],\"cookie_needed\":true,\"websocket\":true}",
body.substring(body.indexOf(',')));
this.service.setAllowedOrigins(null);
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.FORBIDDEN);
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.FORBIDDEN);
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
}
@Test // SPR-11443
@@ -129,7 +176,60 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertEquals("*", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertNull(this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.FORBIDDEN);
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertNull(this.servletResponse.getHeader("Access-Control-Max-Age"));
assertNull(this.servletResponse.getHeader("Vary"));
}
@Test // SPR-12226
public void handleInfoOptionsWithOrigin() throws Exception {
setOrigin("http://mydomain2.com");
this.servletRequest.addHeader("Access-Control-Request-Headers", "Last-Modified");
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Last-Modified", this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertEquals("31536000", this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(null);
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.FORBIDDEN);
this.response.flush();
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertNull(this.servletResponse.getHeader("Access-Control-Max-Age"));
assertNull(this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.FORBIDDEN);
this.response.flush();
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertNull(this.servletResponse.getHeader("Access-Control-Max-Age"));
assertNull(this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Last-Modified", this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods"));

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,7 +16,9 @@
package org.springframework.web.socket.sockjs.transport.handler;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.junit.Before;
@@ -27,6 +29,8 @@ import org.mockito.MockitoAnnotations;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.web.socket.AbstractHttpRequestTests;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.sockjs.transport.SockJsSessionFactory;
import org.springframework.web.socket.sockjs.transport.TransportHandler;
import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService;
@@ -41,6 +45,7 @@ import static org.mockito.BDDMockito.*;
* Test fixture for {@link org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService}.
*
* @author Rossen Stoyanchev
* @author Sebastien Deleuze
*/
public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
@@ -50,11 +55,19 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
private static final String sessionUrlPrefix = "/server1/" + sessionId + "/";
private static final List<String> origins = Arrays.asList("http://mydomain1.com", "http://mydomain2.com");
@Mock private SessionCreatingTransportHandler xhrHandler;
@Mock private TransportHandler xhrSendHandler;
@Mock private SessionCreatingTransportHandler jsonpHandler;
@Mock private TransportHandler jsonpSendHandler;
@Mock private HandshakeTransportHandler wsTransportHandler;
@Mock private WebSocketHandler wsHandler;
@Mock private TaskScheduler taskScheduler;
@@ -75,6 +88,10 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
given(this.xhrHandler.getTransportType()).willReturn(TransportType.XHR);
given(this.xhrHandler.createSession(sessionId, this.wsHandler, attributes)).willReturn(this.session);
given(this.xhrSendHandler.getTransportType()).willReturn(TransportType.XHR_SEND);
given(this.jsonpHandler.getTransportType()).willReturn(TransportType.JSONP);
given(this.jsonpHandler.createSession(sessionId, this.wsHandler, attributes)).willReturn(this.session);
given(this.jsonpSendHandler.getTransportType()).willReturn(TransportType.JSONP_SEND);
given(this.wsTransportHandler.getTransportType()).willReturn(TransportType.WEBSOCKET);
this.service = new TransportHandlingSockJsService(this.taskScheduler, this.xhrHandler, this.xhrSendHandler);
}
@@ -126,10 +143,47 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
verify(taskScheduler).scheduleAtFixedRate(any(Runnable.class), eq(service.getDisconnectDelay()));
assertEquals("no-store, no-cache, must-revalidate, max-age=0", this.response.getHeaders().getCacheControl());
assertEquals("*", this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
}
@Test // SPR-12226
public void handleTransportRequestXhrAllowNullOrigin() throws Exception {
String sockJsPath = sessionUrlPrefix + "xhr";
setRequest("POST", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(null);
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
}
@Test // SPR-12226
public void handleTransportRequestXhrAllowedOriginsMatch() throws Exception {
String sockJsPath = sessionUrlPrefix + "xhr";
setRequest("POST", sockJsPrefix + sockJsPath);
setOrigin(origins.get(0));
this.service.setAllowedOrigins(origins);
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(200, this.servletResponse.getStatus());
assertEquals(origins.get(0), this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertEquals("true", this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
}
@Test // SPR-12226
public void handleTransportRequestXhrAllowedOriginsNoMatch() throws Exception {
String sockJsPath = sessionUrlPrefix + "xhr";
setRequest("POST", sockJsPrefix + sockJsPath);
setOrigin("http://mydomain3.com");
this.service.setAllowedOrigins(origins);
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(403, this.servletResponse.getStatus());
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
}
@Test
public void handleTransportRequestXhrOptions() throws Exception {
String sockJsPath = sessionUrlPrefix + "xhr";
@@ -137,9 +191,22 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(204, this.servletResponse.getStatus());
assertEquals("*", this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertEquals("true", this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
assertEquals("OPTIONS, POST", this.response.getHeaders().getFirst("Access-Control-Allow-Methods"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Methods"));
}
@Test // SPR-12226
public void handleTransportRequestXhrOptionsAllowNullOrigin() throws Exception {
String sockJsPath = sessionUrlPrefix + "xhr";
setRequest("OPTIONS", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(null);
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(403, this.servletResponse.getStatus());
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Methods"));
}
@Test
@@ -176,8 +243,56 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
verify(this.xhrSendHandler).handleRequest(this.request, this.response, this.wsHandler, this.session);
}
@Test
public void handleTransportRequestJsonp() throws Exception {
TransportHandlingSockJsService jsonpService = new TransportHandlingSockJsService(this.taskScheduler, this.jsonpHandler, this.jsonpSendHandler);
String sockJsPath = sessionUrlPrefix+ "jsonp";
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse();
jsonpService.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse();
jsonpService.setAllowedOrigins(null);
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus());
}
@Test
public void handleTransportRequestWebsocket() throws Exception {
TransportHandlingSockJsService wsService = new TransportHandlingSockJsService(this.taskScheduler, this.wsTransportHandler);
String sockJsPath = "/websocket";
setRequest("GET", sockJsPrefix + sockJsPath);
wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(403, this.servletResponse.getStatus());
resetRequestAndResponse();
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
wsService.setHandshakeInterceptors(Arrays.asList(interceptor));
setRequest("GET", sockJsPrefix + sockJsPath);
setOrigin("http://mydomain1.com");
wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(403, this.servletResponse.getStatus());
resetRequestAndResponse();
setRequest("GET", sockJsPrefix + sockJsPath);
setOrigin("http://mydomain2.com");
wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(403, this.servletResponse.getStatus());
}
interface SessionCreatingTransportHandler extends TransportHandler, SockJsSessionFactory {
}
interface HandshakeTransportHandler extends TransportHandler, HandshakeHandler {
}
}