Allow "ws" and "wss" for isValidCorsOrigin checks

Issue: SPR-12956
This commit is contained in:
Rossen Stoyanchev
2015-05-03 10:23:13 +02:00
parent 222f6998e4
commit 68ecb92d1f
7 changed files with 117 additions and 148 deletions

View File

@@ -17,7 +17,9 @@
package org.springframework.web.socket.server.support;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
@@ -39,31 +41,17 @@ import org.springframework.web.socket.WebSocketHandler;
public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
@Test(expected = IllegalArgumentException.class)
public void nullAllowedOriginList() {
public void invalidInput() {
new OriginHandshakeInterceptor(null);
}
@Test(expected = IllegalArgumentException.class)
public void invalidAllowedOrigin() {
new OriginHandshakeInterceptor(Arrays.asList("domain.com"));
}
@Test
public void emtpyAllowedOriginList() {
new OriginHandshakeInterceptor(Arrays.asList());
}
@Test
public void validAllowedOrigins() {
new OriginHandshakeInterceptor(Arrays.asList("http://domain.com", "https://domain.com", "*"));
}
@Test
public void originValueMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com"));
List<String> allowed = Collections.singletonList("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@@ -73,7 +61,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain2.com"));
List<String> allowed = Collections.singletonList("http://mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@@ -83,7 +72,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
List<String> allowed = Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@@ -93,7 +83,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain4.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
List<String> allowed = Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@@ -117,7 +108,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("*"));
interceptor.setAllowedOrigins(Collections.singletonList("*"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@@ -128,7 +119,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
this.servletRequest.setServerName("mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Collections.emptyList());
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@@ -139,7 +130,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain3.com");
this.servletRequest.setServerName("mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Collections.emptyList());
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}

View File

@@ -16,11 +16,14 @@
package org.springframework.web.socket.sockjs.transport.handler;
import static org.junit.Assert.*;
import static org.mockito.BDDMockito.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.hamcrest.Matchers;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
@@ -41,9 +44,6 @@ import org.springframework.web.socket.sockjs.transport.TransportType;
import org.springframework.web.socket.sockjs.transport.session.StubSockJsServiceConfig;
import org.springframework.web.socket.sockjs.transport.session.TestSockJsSession;
import static org.junit.Assert.*;
import static org.mockito.BDDMockito.*;
/**
* Test fixture for {@link org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService}.
*
@@ -125,26 +125,10 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
}
@Test(expected = IllegalArgumentException.class)
public void nullAllowedOriginList() {
public void invalidAllowedOrigins() {
this.service.setAllowedOrigins(null);
}
@Test
public void emptyAllowedOriginList() {
this.service.setAllowedOrigins(Arrays.asList());
assertThat(this.service.getAllowedOrigins(), Matchers.empty());
}
@Test(expected = IllegalArgumentException.class)
public void invalidAllowedOrigin() {
this.service.setAllowedOrigins(Arrays.asList("domain.com"));
}
@Test
public void validAllowedOrigins() {
this.service.setAllowedOrigins(Arrays.asList("http://domain.com", "https://domain.com", "*"));
}
@Test
public void customizedTransportHandlerList() {
TransportHandlingSockJsService service = new TransportHandlingSockJsService(
@@ -268,13 +252,13 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
assertEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse();
jsonpService.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
jsonpService.setAllowedOrigins(Collections.singletonList("http://mydomain1.com"));
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse();
jsonpService.setAllowedOrigins(Arrays.asList("*"));
jsonpService.setAllowedOrigins(Collections.singletonList("*"));
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
@@ -289,8 +273,9 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
assertNotEquals(403, this.servletResponse.getStatus());
resetRequestAndResponse();
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com"));
wsService.setHandshakeInterceptors(Arrays.asList(interceptor));
List<String> allowed = Collections.singletonList("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
wsService.setHandshakeInterceptors(Collections.singletonList(interceptor));
setRequest("GET", sockJsPrefix + sockJsPath);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
@@ -313,14 +298,14 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
resetRequestAndResponse();
setRequest("GET", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
this.service.setAllowedOrigins(Collections.singletonList("http://mydomain1.com"));
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus());
assertNull(this.servletResponse.getHeader("X-Frame-Options"));
resetRequestAndResponse();
setRequest("GET", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(Arrays.asList("*"));
this.service.setAllowedOrigins(Collections.singletonList("*"));
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
assertNull(this.servletResponse.getHeader("X-Frame-Options"));