diff --git a/spring-web/src/main/java/org/springframework/web/cors/CorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/CorsProcessor.java
index 41b85c1daf..cdba41d403 100644
--- a/spring-web/src/main/java/org/springframework/web/cors/CorsProcessor.java
+++ b/spring-web/src/main/java/org/springframework/web/cors/CorsProcessor.java
@@ -21,35 +21,31 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
- * Contract for handling CORS preflight requests and intercepting CORS simple
- * and actual requests.
+ * A strategy that takes a request and a {@link CorsConfiguration} and updates
+ * the response.
+ *
+ *
This component is not concerned with how a {@code CorsConfiguration} is
+ * selected but rather takes follow-up actions such as applying CORS validation
+ * checks and either rejecting the response or adding CORS headers to the
+ * response.
*
* @author Sebastien Deleuze
+ * @author Rossen Stoyanchev
* @since 4.2
* @see CORS W3C recommandation
+ * @see org.springframework.web.servlet.handler.AbstractHandlerMapping#setCorsProcessor
*/
public interface CorsProcessor {
/**
- * Process a preflight CORS request given a {@link CorsConfiguration}.
- * If the request is not a valid CORS pre-flight request or if it does not
- * comply with the configuration it should be rejected.
- * If the request is valid and complies with the configuration, CORS headers
- * should be added to the response.
+ * Process a request given a {@code CorsConfiguration}.
+ *
+ * @param configuration the applicable CORS configuration, possibly {@code null}
+ * @param request the current request
+ * @param response the current response
* @return {@code false} if the request is rejected, else {@code true}.
*/
- boolean processPreFlightRequest(CorsConfiguration conf, HttpServletRequest request,
- HttpServletResponse response) throws IOException;
-
- /**
- * Process a simple or actual CORS request given a {@link CorsConfiguration}.
- * If the request is not a valid CORS simple or actual request or if it does
- * not comply with the configuration, it should be rejected.
- * If the request is valid and comply with the configuration, this method adds the related
- * CORS headers to the response.
- * @return {@code false} if the request is rejected, else {@code true}.
- */
- boolean processActualRequest(CorsConfiguration conf, HttpServletRequest request,
+ boolean processRequest(CorsConfiguration configuration, HttpServletRequest request,
HttpServletResponse response) throws IOException;
}
diff --git a/spring-web/src/main/java/org/springframework/web/cors/CorsUtils.java b/spring-web/src/main/java/org/springframework/web/cors/CorsUtils.java
index e3a985eeb0..1bb8ade9f4 100644
--- a/spring-web/src/main/java/org/springframework/web/cors/CorsUtils.java
+++ b/spring-web/src/main/java/org/springframework/web/cors/CorsUtils.java
@@ -34,17 +34,14 @@ public class CorsUtils {
* Returns {@code true} if the request is a valid CORS one.
*/
public static boolean isCorsRequest(HttpServletRequest request) {
- return request.getHeader(HttpHeaders.ORIGIN) != null;
+ return (request.getHeader(HttpHeaders.ORIGIN) != null);
}
/**
* Returns {@code true} if the request is a valid CORS pre-flight one.
*/
public static boolean isPreFlightRequest(HttpServletRequest request) {
- if (!isCorsRequest(request)) {
- return false;
- }
- return request.getMethod().equals(HttpMethod.OPTIONS.name());
+ return (isCorsRequest(request) && request.getMethod().equals(HttpMethod.OPTIONS.name()));
}
}
diff --git a/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java
index 6a141897ba..12fb2710b3 100644
--- a/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java
+++ b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java
@@ -39,7 +39,11 @@ import org.springframework.util.CollectionUtils;
/**
* Default implementation of {@link CorsProcessor}, as defined by the
- * CORS W3C recommandation.
+ * CORS W3C recommendation.
+ *
+ *
Note that when input {@link CorsConfiguration} is {@code null}, this
+ * implementation does not reject simple or actual requests outright but simply
+ * avoid adding CORS headers to the response.
*
* @author Sebastien Deleuze
* @author Rossen Stoyanhcev
@@ -49,48 +53,37 @@ public class DefaultCorsProcessor implements CorsProcessor {
private static final Charset UTF8_CHARSET = Charset.forName("UTF-8");
-
- protected final Log logger = LogFactory.getLog(getClass());
+ private static final Log logger = LogFactory.getLog(DefaultCorsProcessor.class);
@Override
- public boolean processPreFlightRequest(CorsConfiguration config, HttpServletRequest request,
+ public boolean processRequest(CorsConfiguration config, HttpServletRequest request,
HttpServletResponse response) throws IOException {
- Assert.isTrue(CorsUtils.isPreFlightRequest(request));
-
- ServerHttpResponse serverResponse = new ServletServerHttpResponse(response);
- if (responseHasCors(serverResponse)) {
+ if (!CorsUtils.isCorsRequest(request)) {
return true;
}
- ServerHttpRequest serverRequest = new ServletServerHttpRequest(request);
- if (handleInternal(serverRequest, serverResponse, config, true)) {
- serverResponse.flush();
- return true;
- }
-
- return false;
- }
-
- @Override
- public boolean processActualRequest(CorsConfiguration config, HttpServletRequest request,
- HttpServletResponse response) throws IOException {
-
- Assert.isTrue(CorsUtils.isCorsRequest(request) && !CorsUtils.isPreFlightRequest(request));
-
ServletServerHttpResponse serverResponse = new ServletServerHttpResponse(response);
+ ServletServerHttpRequest serverRequest = new ServletServerHttpRequest(request);
+
if (responseHasCors(serverResponse)) {
return true;
}
- ServletServerHttpRequest serverRequest = new ServletServerHttpRequest(request);
- if (handleInternal(serverRequest, serverResponse, config, false)) {
- serverResponse.flush();
- return true;
+ boolean preFlightRequest = CorsUtils.isPreFlightRequest(request);
+
+ if (config == null) {
+ if (preFlightRequest) {
+ rejectRequest(serverResponse);
+ return false;
+ }
+ else {
+ return true;
+ }
}
- return false;
+ return handleInternal(serverRequest, serverResponse, config, preFlightRequest);
}
private boolean responseHasCors(ServerHttpResponse response) {
@@ -107,20 +100,33 @@ public class DefaultCorsProcessor implements CorsProcessor {
return hasAllowOrigin;
}
+ /**
+ * Invoked when one of the CORS checks failed.
+ * The default implementation sets the response status to 403 and writes
+ * "Invalid CORS request" to the response.
+ */
+ protected void rejectRequest(ServerHttpResponse response) throws IOException {
+ response.setStatusCode(HttpStatus.FORBIDDEN);
+ response.getBody().write("Invalid CORS request".getBytes(UTF8_CHARSET));
+ }
+
+ /**
+ * Handle the given request.
+ */
protected boolean handleInternal(ServerHttpRequest request, ServerHttpResponse response,
- CorsConfiguration config, boolean isPreFlight) throws IOException {
+ CorsConfiguration config, boolean preFlightRequest) throws IOException {
String requestOrigin = request.getHeaders().getOrigin();
String allowOrigin = checkOrigin(config, requestOrigin);
- HttpMethod requestMethod = getMethodToUse(request, isPreFlight);
+ HttpMethod requestMethod = getMethodToUse(request, preFlightRequest);
List allowMethods = checkMethods(config, requestMethod);
- List requestHeaders = getHeadersToUse(request, isPreFlight);
+ List requestHeaders = getHeadersToUse(request, preFlightRequest);
List allowHeaders = checkHeaders(config, requestHeaders);
- if (allowOrigin == null || allowMethods == null || (isPreFlight && allowHeaders == null)) {
- handleInvalidCorsRequest(response);
+ if (allowOrigin == null || allowMethods == null || (preFlightRequest && allowHeaders == null)) {
+ rejectRequest(response);
return false;
}
@@ -128,11 +134,11 @@ public class DefaultCorsProcessor implements CorsProcessor {
responseHeaders.setAccessControlAllowOrigin(allowOrigin);
responseHeaders.add(HttpHeaders.VARY, HttpHeaders.ORIGIN);
- if (isPreFlight) {
+ if (preFlightRequest) {
responseHeaders.setAccessControlAllowMethods(allowMethods);
}
- if (isPreFlight && !allowHeaders.isEmpty()) {
+ if (preFlightRequest && !allowHeaders.isEmpty()) {
responseHeaders.setAccessControlAllowHeaders(allowHeaders);
}
@@ -144,10 +150,11 @@ public class DefaultCorsProcessor implements CorsProcessor {
responseHeaders.setAccessControlAllowCredentials(true);
}
- if (isPreFlight && config.getMaxAge() != null) {
+ if (preFlightRequest && config.getMaxAge() != null) {
responseHeaders.setAccessControlMaxAge(config.getMaxAge());
}
+ response.flush();
return true;
}
@@ -187,14 +194,4 @@ public class DefaultCorsProcessor implements CorsProcessor {
return (isPreFlight ? headers.getAccessControlRequestHeaders() : new ArrayList(headers.keySet()));
}
- /**
- * Invoked when one of the CORS checks failed.
- * The default implementation sets the response status to 403 and writes
- * "Invalid CORS request" to the response.
- */
- protected void handleInvalidCorsRequest(ServerHttpResponse response) throws IOException {
- response.setStatusCode(HttpStatus.FORBIDDEN);
- response.getBody().write("Invalid CORS request".getBytes(UTF8_CHARSET));
- }
-
}
diff --git a/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java b/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java
index bdcc188035..fa60ae5e77 100644
--- a/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java
+++ b/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java
@@ -32,14 +32,19 @@ import static org.junit.Assert.*;
* Test {@link DefaultCorsProcessor} with simple or preflight CORS request.
*
* @author Sebastien Deleuze
+ * @author Rossen Stoyanchev
*/
public class DefaultCorsProcessorTests {
private MockHttpServletRequest request;
+
private MockHttpServletResponse response;
+
private DefaultCorsProcessor processor;
+
private CorsConfiguration conf;
+
@Before
public void setup() {
this.request = new MockHttpServletRequest();
@@ -51,31 +56,34 @@ public class DefaultCorsProcessorTests {
this.processor = new DefaultCorsProcessor();
}
- @Test(expected = IllegalArgumentException.class)
- public void actualRequestWithoutOriginHeader() throws Exception {
- this.request.setMethod(HttpMethod.GET.name());
- this.processor.processActualRequest(this.conf, request, response);
- }
-
@Test
public void actualRequestWithOriginHeader() throws Exception {
this.request.setMethod(HttpMethod.GET.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
- this.processor.processActualRequest(this.conf, request, response);
- assertFalse(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ this.processor.processRequest(this.conf, request, response);
+ assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus());
}
@Test
- public void actualRequestwithOriginHeaderAndAllowedOrigin() throws Exception {
+ public void actualRequestWithOriginHeaderAndNullConfig() throws Exception {
+ this.request.setMethod(HttpMethod.GET.name());
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
+ this.processor.processRequest(null, request, response);
+ assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals(HttpServletResponse.SC_OK, response.getStatus());
+ }
+
+ @Test
+ public void actualRequestWithOriginHeaderAndAllowedOrigin() throws Exception {
this.request.setMethod(HttpMethod.GET.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.conf.addAllowedOrigin("*");
- this.processor.processActualRequest(this.conf, request, response);
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ this.processor.processRequest(this.conf, request, response);
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("*", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
- assertFalse(response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE));
- assertFalse(response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS));
+ assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE));
+ assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@@ -87,10 +95,10 @@ public class DefaultCorsProcessorTests {
this.conf.addAllowedOrigin("http://domain2.com/test.html");
this.conf.addAllowedOrigin("http://domain2.com/logout.html");
this.conf.setAllowCredentials(true);
- this.processor.processActualRequest(this.conf, request, response);
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ this.processor.processRequest(this.conf, request, response);
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("http://domain2.com/test.html", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals("true", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@@ -101,10 +109,10 @@ public class DefaultCorsProcessorTests {
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.conf.addAllowedOrigin("*");
this.conf.setAllowCredentials(true);
- this.processor.processActualRequest(this.conf, request, response);
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ this.processor.processRequest(this.conf, request, response);
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("http://domain2.com/test.html", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals("true", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@@ -114,8 +122,8 @@ public class DefaultCorsProcessorTests {
this.request.setMethod(HttpMethod.GET.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.conf.addAllowedOrigin("http://domain2.com/TEST.html");
- this.processor.processActualRequest(this.conf, request, response);
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ this.processor.processRequest(this.conf, request, response);
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@@ -126,12 +134,12 @@ public class DefaultCorsProcessorTests {
this.conf.addExposedHeader("header1");
this.conf.addExposedHeader("header2");
this.conf.addAllowedOrigin("http://domain2.com/test.html");
- this.processor.processActualRequest(this.conf, request, response);
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ this.processor.processRequest(this.conf, request, response);
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("http://domain2.com/test.html", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS));
- assertTrue(response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1"));
- assertTrue(response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2"));
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS));
+ assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1"));
+ assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2"));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@@ -141,7 +149,7 @@ public class DefaultCorsProcessorTests {
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.conf.addAllowedOrigin("*");
- this.processor.processPreFlightRequest(this.conf, request, response);
+ this.processor.processRequest(this.conf, request, response);
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@@ -151,7 +159,7 @@ public class DefaultCorsProcessorTests {
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "DELETE");
this.conf.addAllowedOrigin("*");
- this.processor.processPreFlightRequest(this.conf, request, response);
+ this.processor.processRequest(this.conf, request, response);
assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus());
}
@@ -161,24 +169,17 @@ public class DefaultCorsProcessorTests {
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.conf.addAllowedOrigin("*");
- this.processor.processPreFlightRequest(this.conf, request, response);
+ this.processor.processRequest(this.conf, request, response);
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
assertEquals("GET", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS));
}
- @Test(expected = IllegalArgumentException.class)
- public void preflightRequestWithoutOriginHeader() throws Exception {
- this.request.setMethod(HttpMethod.OPTIONS.name());
- this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
- this.processor.processPreFlightRequest(this.conf, request, response);
- }
-
@Test
public void preflightRequestTestWithOriginButWithoutOtherHeaders() throws Exception {
this.request.setMethod(HttpMethod.OPTIONS.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
- this.processor.processPreFlightRequest(this.conf, request, response);
- assertFalse(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ this.processor.processRequest(this.conf, request, response);
+ assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus());
}
@@ -187,8 +188,8 @@ public class DefaultCorsProcessorTests {
this.request.setMethod(HttpMethod.OPTIONS.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1");
- this.processor.processPreFlightRequest(this.conf, request, response);
- assertFalse(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ this.processor.processRequest(this.conf, request, response);
+ assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus());
}
@@ -198,8 +199,8 @@ public class DefaultCorsProcessorTests {
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
- this.processor.processPreFlightRequest(this.conf, request, response);
- assertFalse(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ this.processor.processRequest(this.conf, request, response);
+ assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus());
}
@@ -214,12 +215,12 @@ public class DefaultCorsProcessorTests {
this.conf.addAllowedMethod("PUT");
this.conf.addAllowedHeader("header1");
this.conf.addAllowedHeader("header2");
- this.processor.processPreFlightRequest(this.conf, request, response);
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ this.processor.processRequest(this.conf, request, response);
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("*", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS));
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS));
assertEquals("GET,PUT", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS));
- assertFalse(response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE));
+ assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@@ -234,10 +235,10 @@ public class DefaultCorsProcessorTests {
this.conf.addAllowedOrigin("http://domain2.com/logout.html");
this.conf.addAllowedHeader("Header1");
this.conf.setAllowCredentials(true);
- this.processor.processPreFlightRequest(this.conf, request, response);
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ this.processor.processRequest(this.conf, request, response);
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("http://domain2.com/test.html", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals("true", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@@ -253,8 +254,8 @@ public class DefaultCorsProcessorTests {
this.conf.addAllowedOrigin("http://domain2.com/logout.html");
this.conf.addAllowedHeader("Header1");
this.conf.setAllowCredentials(true);
- this.processor.processPreFlightRequest(this.conf, request, response);
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ this.processor.processRequest(this.conf, request, response);
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("http://domain2.com/test.html", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@@ -269,12 +270,12 @@ public class DefaultCorsProcessorTests {
this.conf.addAllowedHeader("Header2");
this.conf.addAllowedHeader("Header3");
this.conf.addAllowedOrigin("http://domain2.com/test.html");
- this.processor.processPreFlightRequest(this.conf, request, response);
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS));
- assertTrue(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1"));
- assertTrue(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2"));
- assertFalse(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header3"));
+ this.processor.processRequest(this.conf, request, response);
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS));
+ assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1"));
+ assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2"));
+ assertFalse(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header3"));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@@ -286,13 +287,24 @@ public class DefaultCorsProcessorTests {
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.conf.addAllowedHeader("*");
this.conf.addAllowedOrigin("http://domain2.com/test.html");
- this.processor.processPreFlightRequest(this.conf, request, response);
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
- assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS));
- assertTrue(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1"));
- assertTrue(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2"));
- assertFalse(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("*"));
+ this.processor.processRequest(this.conf, request, response);
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS));
+ assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1"));
+ assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2"));
+ assertFalse(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("*"));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
+ @Test
+ public void preflightRequestWithNullConfig() throws Exception {
+ this.request.setMethod(HttpMethod.OPTIONS.name());
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ this.conf.addAllowedOrigin("*");
+ this.processor.processRequest(null, request, response);
+ assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus());
+ }
+
}
\ No newline at end of file
diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/CorsConfigurer.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/CorsConfigurer.java
index 82f687c6a5..b9f7e7c440 100644
--- a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/CorsConfigurer.java
+++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/CorsConfigurer.java
@@ -17,7 +17,7 @@
package org.springframework.web.servlet.config.annotation;
import java.util.ArrayList;
-import java.util.HashMap;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -34,6 +34,7 @@ public class CorsConfigurer {
private final List registrations = new ArrayList();
+
/**
* Enable cross origin requests on the specified path patterns. If no path pattern is specified,
* cross-origin request handling is mapped on "/**" .
@@ -47,7 +48,7 @@ public class CorsConfigurer {
}
protected Map getCorsConfigurations() {
- Map configs = new HashMap();
+ Map configs = new LinkedHashMap(this.registrations.size());
for (CorsRegistration registration : this.registrations) {
for (String pathPattern : registration.getPathPatterns()) {
configs.put(pathPattern, registration.getCorsConfiguration());
diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java
index c11bc7d47d..2f63e8afc2 100644
--- a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java
+++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java
@@ -239,7 +239,7 @@ public class WebMvcConfigurationSupport implements ApplicationContextAware, Serv
handlerMapping.setOrder(0);
handlerMapping.setInterceptors(getInterceptors());
handlerMapping.setContentNegotiationManager(mvcContentNegotiationManager());
- handlerMapping.setCorsConfigurations(getCorsConfigurations());
+ handlerMapping.setCorsConfiguration(getCorsConfigurations());
PathMatchConfigurer configurer = getPathMatchConfigurer();
if (configurer.isUseSuffixPatternMatch() != null) {
@@ -371,7 +371,7 @@ public class WebMvcConfigurationSupport implements ApplicationContextAware, Serv
handlerMapping.setPathMatcher(mvcPathMatcher());
handlerMapping.setUrlPathHelper(mvcUrlPathHelper());
handlerMapping.setInterceptors(getInterceptors());
- handlerMapping.setCorsConfigurations(getCorsConfigurations());
+ handlerMapping.setCorsConfiguration(getCorsConfigurations());
return handlerMapping;
}
@@ -391,7 +391,7 @@ public class WebMvcConfigurationSupport implements ApplicationContextAware, Serv
BeanNameUrlHandlerMapping mapping = new BeanNameUrlHandlerMapping();
mapping.setOrder(2);
mapping.setInterceptors(getInterceptors());
- mapping.setCorsConfigurations(getCorsConfigurations());
+ mapping.setCorsConfiguration(getCorsConfigurations());
return mapping;
}
@@ -411,7 +411,7 @@ public class WebMvcConfigurationSupport implements ApplicationContextAware, Serv
handlerMapping.setUrlPathHelper(mvcUrlPathHelper());
handlerMapping.setInterceptors(new HandlerInterceptor[] {
new ResourceUrlProviderExposingInterceptor(mvcResourceUrlProvider())});
- handlerMapping.setCorsConfigurations(getCorsConfigurations());
+ handlerMapping.setCorsConfiguration(getCorsConfigurations());
}
else {
handlerMapping = new EmptyHandlerMapping();
diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java
index b97ebf78e2..a68eb3903c 100644
--- a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java
+++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java
@@ -20,6 +20,7 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
@@ -83,7 +84,8 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
private CorsProcessor corsProcessor = new DefaultCorsProcessor();
- private Map corsConfigurations = new HashMap();
+ private final Map corsConfiguration =
+ new LinkedHashMap();
/**
@@ -199,33 +201,41 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
}
/**
+ * Configure a custom {@link CorsProcessor} to use to apply the matched
+ * {@link CorsConfiguration} for a request.
+ * By default {@link DefaultCorsProcessor} is used.
* @since 4.2
*/
public void setCorsProcessor(CorsProcessor corsProcessor) {
Assert.notNull(corsProcessor, "CorsProcessor must not be null");
this.corsProcessor = corsProcessor;
}
-
+
/**
- * Map the specified {@link CorsConfiguration} to the specified path.
- *
- * @param pathPattern the path to use.
Supports direct URL matches and Ant-style pattern matches.
- * For syntax details, see the {@link org.springframework.util.AntPathMatcher} javadoc.
- * @param config the CORS configuration to use
- * @since 4.2
+ * Return the configured {@link CorsProcessor}.
*/
- public void registerCorsConfiguration(String pathPattern, CorsConfiguration config) {
- this.corsConfigurations.put(pathPattern, config);
+ public CorsProcessor getCorsProcessor() {
+ return this.corsProcessor;
}
/**
- * Set the {@link CorsConfiguration} map.
- *
+ * Set "global" CORS configuration based on URL patterns. By default the first
+ * matching URL pattern is combined with the CORS configuration for the
+ * handler, if any.
* @since 4.2
- * @see #registerCorsConfiguration(String, CorsConfiguration)
*/
- public void setCorsConfigurations(Map corsConfigurations) {
- this.corsConfigurations = corsConfigurations;
+ public void setCorsConfiguration(Map corsConfiguration) {
+ this.corsConfiguration.clear();
+ if (corsConfiguration != null) {
+ this.corsConfiguration.putAll(corsConfiguration);
+ }
+ }
+
+ /**
+ * Get the CORS configuration.
+ */
+ public Map getCorsConfiguration() {
+ return this.corsConfiguration;
}
/**
@@ -353,7 +363,10 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
}
HandlerExecutionChain executionChain = getHandlerExecutionChain(handler, request);
if (CorsUtils.isCorsRequest(request)) {
- executionChain = getCorsHandlerExecutionChain(request, executionChain);
+ CorsConfiguration globalConfig = getCorsConfiguration(request);
+ CorsConfiguration handlerConfig = getCorsConfiguration(handler, request);
+ CorsConfiguration config = (globalConfig != null ? globalConfig.combine(handlerConfig) : handlerConfig);
+ executionChain = getCorsHandlerExecutionChain(request, executionChain, config);
}
return executionChain;
}
@@ -413,11 +426,28 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
return chain;
}
+ /**
+ * Find the "global" CORS configuration for the given URL configured via
+ * {@link #setCorsConfiguration(Map)}.
+ * @param request the request
+ * @return the CORS configuration or {@code null}
+ */
+ protected CorsConfiguration getCorsConfiguration(HttpServletRequest request) {
+ String lookupPath = getUrlPathHelper().getLookupPathForRequest(request);
+ for(Map.Entry entry : getCorsConfiguration().entrySet()) {
+ if (getPathMatcher().match(entry.getKey(), lookupPath)) {
+ return entry.getValue();
+ }
+ }
+ return null;
+ }
+
/**
* Retrieve the CORS configuration for the given handler.
* @param handler the handler to check (never {@code null}).
* @param request the current request.
* @return the CORS configuration for the handler or {@code null}.
+ * @since 4.2
*/
protected CorsConfiguration getCorsConfiguration(Object handler, HttpServletRequest request) {
if (handler instanceof HandlerExecutionChain) {
@@ -431,35 +461,25 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
/**
* Update the HandlerExecutionChain for CORS-related handling.
- *
* For pre-flight requests, the default implementation replaces the selected
* handler with a simple HttpRequestHandler that invokes the configured
* {@link #setCorsProcessor}.
- *
*
For actual requests, the default implementation inserts a
* HandlerInterceptor that makes CORS-related checks and adds CORS headers.
+ * @param request the current request
+ * @param chain the handler chain
+ * @param config the applicable CORS configuration, possibly {@code null}
+ * @since 4.2
*/
protected HandlerExecutionChain getCorsHandlerExecutionChain(HttpServletRequest request,
- HandlerExecutionChain chain) {
-
- CorsConfiguration globalConfig = null;
- String lookupPath = this.urlPathHelper.getLookupPathForRequest(request);
- for(Map.Entry entry : this.corsConfigurations.entrySet()) {
- if(this.pathMatcher.match(entry.getKey(), lookupPath)) {
- globalConfig = entry.getValue();
- }
- }
- CorsConfiguration config = getCorsConfiguration(chain.getHandler(), request);
- config = (globalConfig == null ? config : globalConfig.combine(config));
+ HandlerExecutionChain chain, CorsConfiguration config) {
- if (config != null) {
- if (CorsUtils.isPreFlightRequest(request)) {
- HandlerInterceptor[] interceptors = chain.getInterceptors();
- chain = new HandlerExecutionChain(new PreFlightHandler(config), interceptors);
- }
- else {
- chain.addInterceptor(new CorsInterceptor(config));
- }
+ if (CorsUtils.isPreFlightRequest(request)) {
+ HandlerInterceptor[] interceptors = chain.getInterceptors();
+ chain = new HandlerExecutionChain(new PreFlightHandler(config), interceptors);
+ }
+ else {
+ chain.addInterceptor(new CorsInterceptor(config));
}
return chain;
}
@@ -469,14 +489,15 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
private final CorsConfiguration config;
-
public PreFlightHandler(CorsConfiguration config) {
this.config = config;
}
@Override
- public void handleRequest(HttpServletRequest request, HttpServletResponse response) throws IOException {
- corsProcessor.processPreFlightRequest(this.config, request, response);
+ public void handleRequest(HttpServletRequest request, HttpServletResponse response)
+ throws IOException {
+
+ corsProcessor.processRequest(this.config, request, response);
}
}
@@ -484,16 +505,16 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
private final CorsConfiguration config;
-
public CorsInterceptor(CorsConfiguration config) {
this.config = config;
}
@Override
- public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
- return corsProcessor.processActualRequest(this.config, request, response);
- }
+ public boolean preHandle(HttpServletRequest request, HttpServletResponse response,
+ Object handler) throws Exception {
+ return corsProcessor.processRequest(this.config, request, response);
+ }
}
}
diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java
index 2ed5086fd6..944f6e14bd 100644
--- a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java
+++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java
@@ -16,48 +16,50 @@
package org.springframework.web.servlet.handler;
+import static org.junit.Assert.*;
+
import java.io.IOException;
+import java.util.Collections;
+
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
-import static org.junit.Assert.*;
import org.junit.Before;
import org.junit.Test;
import org.springframework.beans.DirectFieldAccessor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
-import org.springframework.web.cors.CorsConfiguration;
-import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.mock.web.test.MockHttpServletRequest;
-import org.springframework.mock.web.test.MockHttpServletResponse;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.context.support.StaticWebApplicationContext;
+import org.springframework.web.cors.CorsConfiguration;
+import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.support.WebContentGenerator;
/**
+ * Unit tests for CORS-related handling in {@link AbstractHandlerMapping}.
* @author Sebastien Deleuze
+ * @author Rossen Stoyanchev
*/
public class CorsAbstractHandlerMappingTests {
private MockHttpServletRequest request;
- private MockHttpServletResponse response;
+
private AbstractHandlerMapping handlerMapping;
- private StaticWebApplicationContext context;
@Before
public void setup() {
- this.context = new StaticWebApplicationContext();
+ StaticWebApplicationContext context = new StaticWebApplicationContext();
this.handlerMapping = new TestHandlerMapping();
- this.handlerMapping.setApplicationContext(this.context);
+ this.handlerMapping.setApplicationContext(context);
this.request = new MockHttpServletRequest();
this.request.setRemoteHost("domain1.com");
- this.response = new MockHttpServletResponse();
}
@Test
@@ -67,6 +69,7 @@ public class CorsAbstractHandlerMappingTests {
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
+ assertNotNull(chain);
assertTrue(chain.getHandler() instanceof SimpleHandler);
}
@@ -77,7 +80,9 @@ public class CorsAbstractHandlerMappingTests {
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
- assertTrue(chain.getHandler() instanceof SimpleHandler);
+ assertNotNull(chain);
+ assertNotNull(chain.getHandler());
+ assertTrue(chain.getHandler().getClass().getSimpleName().equals("PreFlightHandler"));
}
@Test
@@ -87,6 +92,7 @@ public class CorsAbstractHandlerMappingTests {
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
+ assertNotNull(chain);
assertTrue(chain.getHandler() instanceof CorsAwareHandler);
CorsConfiguration config = getCorsConfiguration(chain, false);
assertNotNull(config);
@@ -100,6 +106,7 @@ public class CorsAbstractHandlerMappingTests {
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
+ assertNotNull(chain);
assertNotNull(chain.getHandler());
assertTrue(chain.getHandler().getClass().getSimpleName().equals("PreFlightHandler"));
CorsConfiguration config = getCorsConfiguration(chain, true);
@@ -111,12 +118,13 @@ public class CorsAbstractHandlerMappingTests {
public void actualRequestWithMappedCorsConfiguration() throws Exception {
CorsConfiguration config = new CorsConfiguration();
config.addAllowedOrigin("*");
- this.handlerMapping.registerCorsConfiguration("/foo", config);
+ this.handlerMapping.setCorsConfiguration(Collections.singletonMap("/foo", config));
this.request.setMethod(RequestMethod.GET.name());
this.request.setRequestURI("/foo");
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
+ assertNotNull(chain);
assertTrue(chain.getHandler() instanceof SimpleHandler);
config = getCorsConfiguration(chain, false);
assertNotNull(config);
@@ -127,12 +135,13 @@ public class CorsAbstractHandlerMappingTests {
public void preflightRequestWithMappedCorsConfiguration() throws Exception {
CorsConfiguration config = new CorsConfiguration();
config.addAllowedOrigin("*");
- this.handlerMapping.registerCorsConfiguration("/foo", config);
+ this.handlerMapping.setCorsConfiguration(Collections.singletonMap("/foo", config));
this.request.setMethod(RequestMethod.OPTIONS.name());
this.request.setRequestURI("/foo");
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
+ assertNotNull(chain);
assertNotNull(chain.getHandler());
assertTrue(chain.getHandler().getClass().getSimpleName().equals("PreFlightHandler"));
config = getCorsConfiguration(chain, true);