Change SockJS and Websocket default allowedOrigins to same origin

This commit adds support for a same origin check that compares
Origin header to Host header. It also changes the default setting
from all origins allowed to only same origin allowed.

Issues: SPR-12697, SPR-12685
(cherry picked from commit 6062e15)
This commit is contained in:
Sebastien Deleuze
2015-02-13 16:56:09 +01:00
committed by Juergen Hoeller
parent cc78d40c6b
commit 23fa37b08b
20 changed files with 363 additions and 123 deletions

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -106,7 +106,8 @@ public class HandlersBeanDefinitionParserTests {
HandshakeHandler handshakeHandler = handler.getHandshakeHandler();
assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof DefaultHandshakeHandler);
assertTrue(handler.getHandshakeInterceptors().isEmpty());
assertFalse(handler.getHandshakeInterceptors().isEmpty());
assertTrue(handler.getHandshakeInterceptors().get(0) instanceof OriginHandshakeInterceptor);
}
else {
assertThat(shm.getUrlMap().keySet(), contains("/test"));
@@ -116,7 +117,8 @@ public class HandlersBeanDefinitionParserTests {
HandshakeHandler handshakeHandler = handler.getHandshakeHandler();
assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof DefaultHandshakeHandler);
assertTrue(handler.getHandshakeInterceptors().isEmpty());
assertFalse(handler.getHandshakeInterceptors().isEmpty());
assertTrue(handler.getHandshakeInterceptors().get(0) instanceof OriginHandshakeInterceptor);
}
}
}
@@ -196,7 +198,7 @@ public class HandlersBeanDefinitionParserTests {
assertEquals(TestHandshakeHandler.class, handler.getHandshakeHandler().getClass());
List<HandshakeInterceptor> interceptors = defaultSockJsService.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class)));
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class)));
}
@Test

View File

@@ -71,7 +71,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
assertNotNull(((WebSocketHttpRequestHandler) entry.getKey()).getWebSocketHandler());
assertTrue(((WebSocketHttpRequestHandler) entry.getKey()).getHandshakeInterceptors().isEmpty());
assertEquals(1, ((WebSocketHttpRequestHandler) entry.getKey()).getHandshakeInterceptors().size());
assertEquals(Arrays.asList("/foo"), entry.getValue());
}
@@ -80,7 +80,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.setAllowedOrigins("http://mydomain.com");
registration.setAllowedOrigins();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
@@ -90,10 +90,18 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass());
}
@Test(expected = IllegalArgumentException.class)
public void noAllowedOrigin() {
@Test
public void sameOrigin() {
WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.setAllowedOrigins();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertNotNull(requestHandler.getWebSocketHandler());
assertEquals(1, requestHandler.getHandshakeInterceptors().size());
assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass());
}
@Test
@@ -158,7 +166,9 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey();
assertNotNull(requestHandler.getWebSocketHandler());
assertSame(handshakeHandler, requestHandler.getHandshakeHandler());
assertEquals(Arrays.asList(interceptor), requestHandler.getHandshakeInterceptors());
assertEquals(2, requestHandler.getHandshakeInterceptors().size());
assertEquals(interceptor, requestHandler.getHandshakeInterceptors().get(0));
assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(1).getClass());
}
@Test
@@ -210,7 +220,9 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers();
WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET);
assertSame(handshakeHandler, transportHandler.getHandshakeHandler());
assertEquals(Arrays.asList(interceptor), sockJsService.getHandshakeInterceptors());
assertEquals(2, sockJsService.getHandshakeInterceptors().size());
assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0));
assertEquals(OriginHandshakeInterceptor.class, sockJsService.getHandshakeInterceptors().get(1).getClass());
}
@Test

View File

@@ -69,12 +69,14 @@ public class WebSocketHandlerRegistrationTests {
Mapping m1 = mappings.get(0);
assertEquals(handler, m1.webSocketHandler);
assertEquals("/foo", m1.path);
assertEquals(0, m1.interceptors.length);
assertEquals(1, m1.interceptors.length);
assertEquals(OriginHandshakeInterceptor.class, m1.interceptors[0].getClass());
Mapping m2 = mappings.get(1);
assertEquals(handler, m2.webSocketHandler);
assertEquals("/bar", m2.path);
assertEquals(0, m2.interceptors.length);
assertEquals(1, m2.interceptors.length);
assertEquals(OriginHandshakeInterceptor.class, m2.interceptors[0].getClass());
}
@Test
@@ -90,12 +92,27 @@ public class WebSocketHandlerRegistrationTests {
Mapping mapping = mappings.get(0);
assertEquals(handler, mapping.webSocketHandler);
assertEquals("/foo", mapping.path);
assertArrayEquals(new HandshakeInterceptor[] {interceptor}, mapping.interceptors);
assertEquals(2, mapping.interceptors.length);
assertEquals(interceptor, mapping.interceptors[0]);
assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
}
@Test(expected = IllegalArgumentException.class)
public void noAllowedOrigin() {
this.registration.addHandler(Mockito.mock(WebSocketHandler.class), "/foo").setAllowedOrigins();
@Test
public void emptyAllowedOrigin() {
WebSocketHandler handler = new TextWebSocketHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).setAllowedOrigins();
List<Mapping> mappings = this.registration.getMappings();
assertEquals(1, mappings.size());
Mapping mapping = mappings.get(0);
assertEquals(handler, mapping.webSocketHandler);
assertEquals("/foo", mapping.path);
assertEquals(2, mapping.interceptors.length);
assertEquals(interceptor, mapping.interceptors[0]);
assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
}
@Test

View File

@@ -39,20 +39,22 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
@Test(expected = IllegalArgumentException.class)
public void nullAllowedOriginList() {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(null);
new OriginHandshakeInterceptor(null);
}
@Test(expected = IllegalArgumentException.class)
public void invalidAllowedOrigin() {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("domain.com"));
new OriginHandshakeInterceptor(Arrays.asList("domain.com"));
}
@Test
public void emtpyAllowedOriginList() {
new OriginHandshakeInterceptor(Arrays.asList());
}
@Test
public void validAllowedOrigins() {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://domain.com", "https://domain.com", "*"));
new OriginHandshakeInterceptor(Arrays.asList("http://domain.com", "https://domain.com", "*"));
}
@Test
@@ -60,8 +62,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@@ -71,8 +72,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain2.com"));
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain2.com"));
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@@ -82,8 +82,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@@ -93,8 +92,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain4.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@@ -123,4 +121,26 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@Test
public void sameOriginMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain2.com");
this.servletRequest.setServerName("mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@Test
public void sameOriginNoMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain3.com");
this.servletRequest.setServerName("mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
}

View File

@@ -110,6 +110,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
@Test // SPR-12226 and SPR-12660
public void handleInfoGetWithOrigin() throws Exception {
this.servletRequest.setServerName("mydomain2.com");
setOrigin("http://mydomain2.com");
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
@@ -135,6 +136,12 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("*"));
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
}
@Test // SPR-11443
@@ -186,6 +193,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
@Test // SPR-12226 and SPR-12660
public void handleInfoOptionsWithOrigin() throws Exception {
this.servletRequest.setServerName("mydomain2.com");
setOrigin("http://mydomain2.com");
this.request.getHeaders().add("Access-Control-Request-Headers", "Last-Modified");
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
@@ -216,6 +224,16 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertEquals("31536000", this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("*"));
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Last-Modified", this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertEquals("31536000", this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
}
@Test // SPR-12283

View File

@@ -122,19 +122,15 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
assertSame(xhrHandler, handlers.get(xhrHandler.getTransportType()));
}
@Test
public void defaultAllowedOrigin() {
assertThat(this.service.getAllowedOrigins(), Matchers.contains("*"));
}
@Test(expected = IllegalArgumentException.class)
public void nullAllowedOriginList() {
this.service.setAllowedOrigins(null);
}
@Test(expected = IllegalArgumentException.class)
@Test
public void emptyAllowedOriginList() {
this.service.setAllowedOrigins(Arrays.asList());
assertThat(this.service.getAllowedOrigins(), Matchers.empty());
}
@Test(expected = IllegalArgumentException.class)
@@ -271,13 +267,19 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
String sockJsPath = sessionUrlPrefix+ "jsonp";
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
assertEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse();
jsonpService.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse();
jsonpService.setAllowedOrigins(Arrays.asList("*"));
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
}
@Test
@@ -289,8 +291,7 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
assertNotEquals(403, this.servletResponse.getStatus());
resetRequestAndResponse();
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com"));
wsService.setHandshakeInterceptors(Arrays.asList(interceptor));
setRequest("GET", sockJsPrefix + sockJsPath);
setOrigin("http://mydomain1.com");
@@ -310,13 +311,21 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
setRequest("GET", sockJsPrefix + sockJsPath);
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
assertNull(this.servletResponse.getHeader("X-Frame-Options"));
assertEquals("SAMEORIGIN", this.servletResponse.getHeader("X-Frame-Options"));
resetRequestAndResponse();
setRequest("GET", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus());
assertNull(this.servletResponse.getHeader("X-Frame-Options"));
resetRequestAndResponse();
setRequest("GET", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(Arrays.asList("*"));
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
assertNull(this.servletResponse.getHeader("X-Frame-Options"));
}