diff --git a/spring-web/src/main/java/org/springframework/web/cors/CorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/CorsProcessor.java index 41b85c1daf..cdba41d403 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/CorsProcessor.java +++ b/spring-web/src/main/java/org/springframework/web/cors/CorsProcessor.java @@ -21,35 +21,31 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; /** - * Contract for handling CORS preflight requests and intercepting CORS simple - * and actual requests. + * A strategy that takes a request and a {@link CorsConfiguration} and updates + * the response. + * + *

This component is not concerned with how a {@code CorsConfiguration} is + * selected but rather takes follow-up actions such as applying CORS validation + * checks and either rejecting the response or adding CORS headers to the + * response. * * @author Sebastien Deleuze + * @author Rossen Stoyanchev * @since 4.2 * @see CORS W3C recommandation + * @see org.springframework.web.servlet.handler.AbstractHandlerMapping#setCorsProcessor */ public interface CorsProcessor { /** - * Process a preflight CORS request given a {@link CorsConfiguration}. - * If the request is not a valid CORS pre-flight request or if it does not - * comply with the configuration it should be rejected. - * If the request is valid and complies with the configuration, CORS headers - * should be added to the response. + * Process a request given a {@code CorsConfiguration}. + * + * @param configuration the applicable CORS configuration, possibly {@code null} + * @param request the current request + * @param response the current response * @return {@code false} if the request is rejected, else {@code true}. */ - boolean processPreFlightRequest(CorsConfiguration conf, HttpServletRequest request, - HttpServletResponse response) throws IOException; - - /** - * Process a simple or actual CORS request given a {@link CorsConfiguration}. - * If the request is not a valid CORS simple or actual request or if it does - * not comply with the configuration, it should be rejected. - * If the request is valid and comply with the configuration, this method adds the related - * CORS headers to the response. - * @return {@code false} if the request is rejected, else {@code true}. - */ - boolean processActualRequest(CorsConfiguration conf, HttpServletRequest request, + boolean processRequest(CorsConfiguration configuration, HttpServletRequest request, HttpServletResponse response) throws IOException; } diff --git a/spring-web/src/main/java/org/springframework/web/cors/CorsUtils.java b/spring-web/src/main/java/org/springframework/web/cors/CorsUtils.java index e3a985eeb0..1bb8ade9f4 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/CorsUtils.java +++ b/spring-web/src/main/java/org/springframework/web/cors/CorsUtils.java @@ -34,17 +34,14 @@ public class CorsUtils { * Returns {@code true} if the request is a valid CORS one. */ public static boolean isCorsRequest(HttpServletRequest request) { - return request.getHeader(HttpHeaders.ORIGIN) != null; + return (request.getHeader(HttpHeaders.ORIGIN) != null); } /** * Returns {@code true} if the request is a valid CORS pre-flight one. */ public static boolean isPreFlightRequest(HttpServletRequest request) { - if (!isCorsRequest(request)) { - return false; - } - return request.getMethod().equals(HttpMethod.OPTIONS.name()); + return (isCorsRequest(request) && request.getMethod().equals(HttpMethod.OPTIONS.name())); } } diff --git a/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java index 6a141897ba..12fb2710b3 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java +++ b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java @@ -39,7 +39,11 @@ import org.springframework.util.CollectionUtils; /** * Default implementation of {@link CorsProcessor}, as defined by the - * CORS W3C recommandation. + * CORS W3C recommendation. + * + *

Note that when input {@link CorsConfiguration} is {@code null}, this + * implementation does not reject simple or actual requests outright but simply + * avoid adding CORS headers to the response. * * @author Sebastien Deleuze * @author Rossen Stoyanhcev @@ -49,48 +53,37 @@ public class DefaultCorsProcessor implements CorsProcessor { private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); - - protected final Log logger = LogFactory.getLog(getClass()); + private static final Log logger = LogFactory.getLog(DefaultCorsProcessor.class); @Override - public boolean processPreFlightRequest(CorsConfiguration config, HttpServletRequest request, + public boolean processRequest(CorsConfiguration config, HttpServletRequest request, HttpServletResponse response) throws IOException { - Assert.isTrue(CorsUtils.isPreFlightRequest(request)); - - ServerHttpResponse serverResponse = new ServletServerHttpResponse(response); - if (responseHasCors(serverResponse)) { + if (!CorsUtils.isCorsRequest(request)) { return true; } - ServerHttpRequest serverRequest = new ServletServerHttpRequest(request); - if (handleInternal(serverRequest, serverResponse, config, true)) { - serverResponse.flush(); - return true; - } - - return false; - } - - @Override - public boolean processActualRequest(CorsConfiguration config, HttpServletRequest request, - HttpServletResponse response) throws IOException { - - Assert.isTrue(CorsUtils.isCorsRequest(request) && !CorsUtils.isPreFlightRequest(request)); - ServletServerHttpResponse serverResponse = new ServletServerHttpResponse(response); + ServletServerHttpRequest serverRequest = new ServletServerHttpRequest(request); + if (responseHasCors(serverResponse)) { return true; } - ServletServerHttpRequest serverRequest = new ServletServerHttpRequest(request); - if (handleInternal(serverRequest, serverResponse, config, false)) { - serverResponse.flush(); - return true; + boolean preFlightRequest = CorsUtils.isPreFlightRequest(request); + + if (config == null) { + if (preFlightRequest) { + rejectRequest(serverResponse); + return false; + } + else { + return true; + } } - return false; + return handleInternal(serverRequest, serverResponse, config, preFlightRequest); } private boolean responseHasCors(ServerHttpResponse response) { @@ -107,20 +100,33 @@ public class DefaultCorsProcessor implements CorsProcessor { return hasAllowOrigin; } + /** + * Invoked when one of the CORS checks failed. + * The default implementation sets the response status to 403 and writes + * "Invalid CORS request" to the response. + */ + protected void rejectRequest(ServerHttpResponse response) throws IOException { + response.setStatusCode(HttpStatus.FORBIDDEN); + response.getBody().write("Invalid CORS request".getBytes(UTF8_CHARSET)); + } + + /** + * Handle the given request. + */ protected boolean handleInternal(ServerHttpRequest request, ServerHttpResponse response, - CorsConfiguration config, boolean isPreFlight) throws IOException { + CorsConfiguration config, boolean preFlightRequest) throws IOException { String requestOrigin = request.getHeaders().getOrigin(); String allowOrigin = checkOrigin(config, requestOrigin); - HttpMethod requestMethod = getMethodToUse(request, isPreFlight); + HttpMethod requestMethod = getMethodToUse(request, preFlightRequest); List allowMethods = checkMethods(config, requestMethod); - List requestHeaders = getHeadersToUse(request, isPreFlight); + List requestHeaders = getHeadersToUse(request, preFlightRequest); List allowHeaders = checkHeaders(config, requestHeaders); - if (allowOrigin == null || allowMethods == null || (isPreFlight && allowHeaders == null)) { - handleInvalidCorsRequest(response); + if (allowOrigin == null || allowMethods == null || (preFlightRequest && allowHeaders == null)) { + rejectRequest(response); return false; } @@ -128,11 +134,11 @@ public class DefaultCorsProcessor implements CorsProcessor { responseHeaders.setAccessControlAllowOrigin(allowOrigin); responseHeaders.add(HttpHeaders.VARY, HttpHeaders.ORIGIN); - if (isPreFlight) { + if (preFlightRequest) { responseHeaders.setAccessControlAllowMethods(allowMethods); } - if (isPreFlight && !allowHeaders.isEmpty()) { + if (preFlightRequest && !allowHeaders.isEmpty()) { responseHeaders.setAccessControlAllowHeaders(allowHeaders); } @@ -144,10 +150,11 @@ public class DefaultCorsProcessor implements CorsProcessor { responseHeaders.setAccessControlAllowCredentials(true); } - if (isPreFlight && config.getMaxAge() != null) { + if (preFlightRequest && config.getMaxAge() != null) { responseHeaders.setAccessControlMaxAge(config.getMaxAge()); } + response.flush(); return true; } @@ -187,14 +194,4 @@ public class DefaultCorsProcessor implements CorsProcessor { return (isPreFlight ? headers.getAccessControlRequestHeaders() : new ArrayList(headers.keySet())); } - /** - * Invoked when one of the CORS checks failed. - * The default implementation sets the response status to 403 and writes - * "Invalid CORS request" to the response. - */ - protected void handleInvalidCorsRequest(ServerHttpResponse response) throws IOException { - response.setStatusCode(HttpStatus.FORBIDDEN); - response.getBody().write("Invalid CORS request".getBytes(UTF8_CHARSET)); - } - } diff --git a/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java b/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java index bdcc188035..fa60ae5e77 100644 --- a/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java +++ b/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java @@ -32,14 +32,19 @@ import static org.junit.Assert.*; * Test {@link DefaultCorsProcessor} with simple or preflight CORS request. * * @author Sebastien Deleuze + * @author Rossen Stoyanchev */ public class DefaultCorsProcessorTests { private MockHttpServletRequest request; + private MockHttpServletResponse response; + private DefaultCorsProcessor processor; + private CorsConfiguration conf; + @Before public void setup() { this.request = new MockHttpServletRequest(); @@ -51,31 +56,34 @@ public class DefaultCorsProcessorTests { this.processor = new DefaultCorsProcessor(); } - @Test(expected = IllegalArgumentException.class) - public void actualRequestWithoutOriginHeader() throws Exception { - this.request.setMethod(HttpMethod.GET.name()); - this.processor.processActualRequest(this.conf, request, response); - } - @Test public void actualRequestWithOriginHeader() throws Exception { this.request.setMethod(HttpMethod.GET.name()); this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); - this.processor.processActualRequest(this.conf, request, response); - assertFalse(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + this.processor.processRequest(this.conf, request, response); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus()); } @Test - public void actualRequestwithOriginHeaderAndAllowedOrigin() throws Exception { + public void actualRequestWithOriginHeaderAndNullConfig() throws Exception { + this.request.setMethod(HttpMethod.GET.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); + this.processor.processRequest(null, request, response); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals(HttpServletResponse.SC_OK, response.getStatus()); + } + + @Test + public void actualRequestWithOriginHeaderAndAllowedOrigin() throws Exception { this.request.setMethod(HttpMethod.GET.name()); this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.conf.addAllowedOrigin("*"); - this.processor.processActualRequest(this.conf, request, response); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + this.processor.processRequest(this.conf, request, response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("*", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); - assertFalse(response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); - assertFalse(response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); assertEquals(HttpServletResponse.SC_OK, response.getStatus()); } @@ -87,10 +95,10 @@ public class DefaultCorsProcessorTests { this.conf.addAllowedOrigin("http://domain2.com/test.html"); this.conf.addAllowedOrigin("http://domain2.com/logout.html"); this.conf.setAllowCredentials(true); - this.processor.processActualRequest(this.conf, request, response); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + this.processor.processRequest(this.conf, request, response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("http://domain2.com/test.html", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals("true", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals(HttpServletResponse.SC_OK, response.getStatus()); } @@ -101,10 +109,10 @@ public class DefaultCorsProcessorTests { this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.conf.addAllowedOrigin("*"); this.conf.setAllowCredentials(true); - this.processor.processActualRequest(this.conf, request, response); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + this.processor.processRequest(this.conf, request, response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("http://domain2.com/test.html", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals("true", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals(HttpServletResponse.SC_OK, response.getStatus()); } @@ -114,8 +122,8 @@ public class DefaultCorsProcessorTests { this.request.setMethod(HttpMethod.GET.name()); this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.conf.addAllowedOrigin("http://domain2.com/TEST.html"); - this.processor.processActualRequest(this.conf, request, response); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + this.processor.processRequest(this.conf, request, response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals(HttpServletResponse.SC_OK, response.getStatus()); } @@ -126,12 +134,12 @@ public class DefaultCorsProcessorTests { this.conf.addExposedHeader("header1"); this.conf.addExposedHeader("header2"); this.conf.addAllowedOrigin("http://domain2.com/test.html"); - this.processor.processActualRequest(this.conf, request, response); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + this.processor.processRequest(this.conf, request, response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("http://domain2.com/test.html", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); - assertTrue(response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1")); - assertTrue(response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2")); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); + assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1")); + assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2")); assertEquals(HttpServletResponse.SC_OK, response.getStatus()); } @@ -141,7 +149,7 @@ public class DefaultCorsProcessorTests { this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); this.conf.addAllowedOrigin("*"); - this.processor.processPreFlightRequest(this.conf, request, response); + this.processor.processRequest(this.conf, request, response); assertEquals(HttpServletResponse.SC_OK, response.getStatus()); } @@ -151,7 +159,7 @@ public class DefaultCorsProcessorTests { this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "DELETE"); this.conf.addAllowedOrigin("*"); - this.processor.processPreFlightRequest(this.conf, request, response); + this.processor.processRequest(this.conf, request, response); assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus()); } @@ -161,24 +169,17 @@ public class DefaultCorsProcessorTests { this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); this.conf.addAllowedOrigin("*"); - this.processor.processPreFlightRequest(this.conf, request, response); + this.processor.processRequest(this.conf, request, response); assertEquals(HttpServletResponse.SC_OK, response.getStatus()); assertEquals("GET", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); } - @Test(expected = IllegalArgumentException.class) - public void preflightRequestWithoutOriginHeader() throws Exception { - this.request.setMethod(HttpMethod.OPTIONS.name()); - this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); - this.processor.processPreFlightRequest(this.conf, request, response); - } - @Test public void preflightRequestTestWithOriginButWithoutOtherHeaders() throws Exception { this.request.setMethod(HttpMethod.OPTIONS.name()); this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); - this.processor.processPreFlightRequest(this.conf, request, response); - assertFalse(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + this.processor.processRequest(this.conf, request, response); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus()); } @@ -187,8 +188,8 @@ public class DefaultCorsProcessorTests { this.request.setMethod(HttpMethod.OPTIONS.name()); this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1"); - this.processor.processPreFlightRequest(this.conf, request, response); - assertFalse(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + this.processor.processRequest(this.conf, request, response); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus()); } @@ -198,8 +199,8 @@ public class DefaultCorsProcessorTests { this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); - this.processor.processPreFlightRequest(this.conf, request, response); - assertFalse(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + this.processor.processRequest(this.conf, request, response); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus()); } @@ -214,12 +215,12 @@ public class DefaultCorsProcessorTests { this.conf.addAllowedMethod("PUT"); this.conf.addAllowedHeader("header1"); this.conf.addAllowedHeader("header2"); - this.processor.processPreFlightRequest(this.conf, request, response); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + this.processor.processRequest(this.conf, request, response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("*", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); assertEquals("GET,PUT", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); - assertFalse(response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); assertEquals(HttpServletResponse.SC_OK, response.getStatus()); } @@ -234,10 +235,10 @@ public class DefaultCorsProcessorTests { this.conf.addAllowedOrigin("http://domain2.com/logout.html"); this.conf.addAllowedHeader("Header1"); this.conf.setAllowCredentials(true); - this.processor.processPreFlightRequest(this.conf, request, response); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + this.processor.processRequest(this.conf, request, response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("http://domain2.com/test.html", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals("true", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals(HttpServletResponse.SC_OK, response.getStatus()); } @@ -253,8 +254,8 @@ public class DefaultCorsProcessorTests { this.conf.addAllowedOrigin("http://domain2.com/logout.html"); this.conf.addAllowedHeader("Header1"); this.conf.setAllowCredentials(true); - this.processor.processPreFlightRequest(this.conf, request, response); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + this.processor.processRequest(this.conf, request, response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("http://domain2.com/test.html", response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals(HttpServletResponse.SC_OK, response.getStatus()); } @@ -269,12 +270,12 @@ public class DefaultCorsProcessorTests { this.conf.addAllowedHeader("Header2"); this.conf.addAllowedHeader("Header3"); this.conf.addAllowedOrigin("http://domain2.com/test.html"); - this.processor.processPreFlightRequest(this.conf, request, response); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS)); - assertTrue(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); - assertTrue(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); - assertFalse(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header3")); + this.processor.processRequest(this.conf, request, response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS)); + assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); + assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); + assertFalse(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header3")); assertEquals(HttpServletResponse.SC_OK, response.getStatus()); } @@ -286,13 +287,24 @@ public class DefaultCorsProcessorTests { this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); this.conf.addAllowedHeader("*"); this.conf.addAllowedOrigin("http://domain2.com/test.html"); - this.processor.processPreFlightRequest(this.conf, request, response); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); - assertTrue(response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS)); - assertTrue(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); - assertTrue(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); - assertFalse(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("*")); + this.processor.processRequest(this.conf, request, response); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS)); + assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); + assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); + assertFalse(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("*")); assertEquals(HttpServletResponse.SC_OK, response.getStatus()); } + @Test + public void preflightRequestWithNullConfig() throws Exception { + this.request.setMethod(HttpMethod.OPTIONS.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.conf.addAllowedOrigin("*"); + this.processor.processRequest(null, request, response); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus()); + } + } \ No newline at end of file diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/CorsConfigurer.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/CorsConfigurer.java index 82f687c6a5..b9f7e7c440 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/CorsConfigurer.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/CorsConfigurer.java @@ -17,7 +17,7 @@ package org.springframework.web.servlet.config.annotation; import java.util.ArrayList; -import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -34,6 +34,7 @@ public class CorsConfigurer { private final List registrations = new ArrayList(); + /** * Enable cross origin requests on the specified path patterns. If no path pattern is specified, * cross-origin request handling is mapped on "/**" . @@ -47,7 +48,7 @@ public class CorsConfigurer { } protected Map getCorsConfigurations() { - Map configs = new HashMap(); + Map configs = new LinkedHashMap(this.registrations.size()); for (CorsRegistration registration : this.registrations) { for (String pathPattern : registration.getPathPatterns()) { configs.put(pathPattern, registration.getCorsConfiguration()); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java index c11bc7d47d..2f63e8afc2 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java @@ -239,7 +239,7 @@ public class WebMvcConfigurationSupport implements ApplicationContextAware, Serv handlerMapping.setOrder(0); handlerMapping.setInterceptors(getInterceptors()); handlerMapping.setContentNegotiationManager(mvcContentNegotiationManager()); - handlerMapping.setCorsConfigurations(getCorsConfigurations()); + handlerMapping.setCorsConfiguration(getCorsConfigurations()); PathMatchConfigurer configurer = getPathMatchConfigurer(); if (configurer.isUseSuffixPatternMatch() != null) { @@ -371,7 +371,7 @@ public class WebMvcConfigurationSupport implements ApplicationContextAware, Serv handlerMapping.setPathMatcher(mvcPathMatcher()); handlerMapping.setUrlPathHelper(mvcUrlPathHelper()); handlerMapping.setInterceptors(getInterceptors()); - handlerMapping.setCorsConfigurations(getCorsConfigurations()); + handlerMapping.setCorsConfiguration(getCorsConfigurations()); return handlerMapping; } @@ -391,7 +391,7 @@ public class WebMvcConfigurationSupport implements ApplicationContextAware, Serv BeanNameUrlHandlerMapping mapping = new BeanNameUrlHandlerMapping(); mapping.setOrder(2); mapping.setInterceptors(getInterceptors()); - mapping.setCorsConfigurations(getCorsConfigurations()); + mapping.setCorsConfiguration(getCorsConfigurations()); return mapping; } @@ -411,7 +411,7 @@ public class WebMvcConfigurationSupport implements ApplicationContextAware, Serv handlerMapping.setUrlPathHelper(mvcUrlPathHelper()); handlerMapping.setInterceptors(new HandlerInterceptor[] { new ResourceUrlProviderExposingInterceptor(mvcResourceUrlProvider())}); - handlerMapping.setCorsConfigurations(getCorsConfigurations()); + handlerMapping.setCorsConfiguration(getCorsConfigurations()); } else { handlerMapping = new EmptyHandlerMapping(); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java index b97ebf78e2..a68eb3903c 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import javax.servlet.http.HttpServletRequest; @@ -83,7 +84,8 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport private CorsProcessor corsProcessor = new DefaultCorsProcessor(); - private Map corsConfigurations = new HashMap(); + private final Map corsConfiguration = + new LinkedHashMap(); /** @@ -199,33 +201,41 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport } /** + * Configure a custom {@link CorsProcessor} to use to apply the matched + * {@link CorsConfiguration} for a request. + *

By default {@link DefaultCorsProcessor} is used. * @since 4.2 */ public void setCorsProcessor(CorsProcessor corsProcessor) { Assert.notNull(corsProcessor, "CorsProcessor must not be null"); this.corsProcessor = corsProcessor; } - + /** - * Map the specified {@link CorsConfiguration} to the specified path. - * - * @param pathPattern the path to use.

Supports direct URL matches and Ant-style pattern matches. - * For syntax details, see the {@link org.springframework.util.AntPathMatcher} javadoc. - * @param config the CORS configuration to use - * @since 4.2 + * Return the configured {@link CorsProcessor}. */ - public void registerCorsConfiguration(String pathPattern, CorsConfiguration config) { - this.corsConfigurations.put(pathPattern, config); + public CorsProcessor getCorsProcessor() { + return this.corsProcessor; } /** - * Set the {@link CorsConfiguration} map. - * + * Set "global" CORS configuration based on URL patterns. By default the first + * matching URL pattern is combined with the CORS configuration for the + * handler, if any. * @since 4.2 - * @see #registerCorsConfiguration(String, CorsConfiguration) */ - public void setCorsConfigurations(Map corsConfigurations) { - this.corsConfigurations = corsConfigurations; + public void setCorsConfiguration(Map corsConfiguration) { + this.corsConfiguration.clear(); + if (corsConfiguration != null) { + this.corsConfiguration.putAll(corsConfiguration); + } + } + + /** + * Get the CORS configuration. + */ + public Map getCorsConfiguration() { + return this.corsConfiguration; } /** @@ -353,7 +363,10 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport } HandlerExecutionChain executionChain = getHandlerExecutionChain(handler, request); if (CorsUtils.isCorsRequest(request)) { - executionChain = getCorsHandlerExecutionChain(request, executionChain); + CorsConfiguration globalConfig = getCorsConfiguration(request); + CorsConfiguration handlerConfig = getCorsConfiguration(handler, request); + CorsConfiguration config = (globalConfig != null ? globalConfig.combine(handlerConfig) : handlerConfig); + executionChain = getCorsHandlerExecutionChain(request, executionChain, config); } return executionChain; } @@ -413,11 +426,28 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport return chain; } + /** + * Find the "global" CORS configuration for the given URL configured via + * {@link #setCorsConfiguration(Map)}. + * @param request the request + * @return the CORS configuration or {@code null} + */ + protected CorsConfiguration getCorsConfiguration(HttpServletRequest request) { + String lookupPath = getUrlPathHelper().getLookupPathForRequest(request); + for(Map.Entry entry : getCorsConfiguration().entrySet()) { + if (getPathMatcher().match(entry.getKey(), lookupPath)) { + return entry.getValue(); + } + } + return null; + } + /** * Retrieve the CORS configuration for the given handler. * @param handler the handler to check (never {@code null}). * @param request the current request. * @return the CORS configuration for the handler or {@code null}. + * @since 4.2 */ protected CorsConfiguration getCorsConfiguration(Object handler, HttpServletRequest request) { if (handler instanceof HandlerExecutionChain) { @@ -431,35 +461,25 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport /** * Update the HandlerExecutionChain for CORS-related handling. - * *

For pre-flight requests, the default implementation replaces the selected * handler with a simple HttpRequestHandler that invokes the configured * {@link #setCorsProcessor}. - * *

For actual requests, the default implementation inserts a * HandlerInterceptor that makes CORS-related checks and adds CORS headers. + * @param request the current request + * @param chain the handler chain + * @param config the applicable CORS configuration, possibly {@code null} + * @since 4.2 */ protected HandlerExecutionChain getCorsHandlerExecutionChain(HttpServletRequest request, - HandlerExecutionChain chain) { - - CorsConfiguration globalConfig = null; - String lookupPath = this.urlPathHelper.getLookupPathForRequest(request); - for(Map.Entry entry : this.corsConfigurations.entrySet()) { - if(this.pathMatcher.match(entry.getKey(), lookupPath)) { - globalConfig = entry.getValue(); - } - } - CorsConfiguration config = getCorsConfiguration(chain.getHandler(), request); - config = (globalConfig == null ? config : globalConfig.combine(config)); + HandlerExecutionChain chain, CorsConfiguration config) { - if (config != null) { - if (CorsUtils.isPreFlightRequest(request)) { - HandlerInterceptor[] interceptors = chain.getInterceptors(); - chain = new HandlerExecutionChain(new PreFlightHandler(config), interceptors); - } - else { - chain.addInterceptor(new CorsInterceptor(config)); - } + if (CorsUtils.isPreFlightRequest(request)) { + HandlerInterceptor[] interceptors = chain.getInterceptors(); + chain = new HandlerExecutionChain(new PreFlightHandler(config), interceptors); + } + else { + chain.addInterceptor(new CorsInterceptor(config)); } return chain; } @@ -469,14 +489,15 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport private final CorsConfiguration config; - public PreFlightHandler(CorsConfiguration config) { this.config = config; } @Override - public void handleRequest(HttpServletRequest request, HttpServletResponse response) throws IOException { - corsProcessor.processPreFlightRequest(this.config, request, response); + public void handleRequest(HttpServletRequest request, HttpServletResponse response) + throws IOException { + + corsProcessor.processRequest(this.config, request, response); } } @@ -484,16 +505,16 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport private final CorsConfiguration config; - public CorsInterceptor(CorsConfiguration config) { this.config = config; } @Override - public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception { - return corsProcessor.processActualRequest(this.config, request, response); - } + public boolean preHandle(HttpServletRequest request, HttpServletResponse response, + Object handler) throws Exception { + return corsProcessor.processRequest(this.config, request, response); + } } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java index 2ed5086fd6..944f6e14bd 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java @@ -16,48 +16,50 @@ package org.springframework.web.servlet.handler; +import static org.junit.Assert.*; + import java.io.IOException; +import java.util.Collections; + import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import static org.junit.Assert.*; import org.junit.Before; import org.junit.Test; import org.springframework.beans.DirectFieldAccessor; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; -import org.springframework.web.cors.CorsConfiguration; -import org.springframework.web.cors.CorsConfigurationSource; import org.springframework.mock.web.test.MockHttpServletRequest; -import org.springframework.mock.web.test.MockHttpServletResponse; import org.springframework.web.HttpRequestHandler; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.context.support.StaticWebApplicationContext; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.cors.CorsConfigurationSource; import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.support.WebContentGenerator; /** + * Unit tests for CORS-related handling in {@link AbstractHandlerMapping}. * @author Sebastien Deleuze + * @author Rossen Stoyanchev */ public class CorsAbstractHandlerMappingTests { private MockHttpServletRequest request; - private MockHttpServletResponse response; + private AbstractHandlerMapping handlerMapping; - private StaticWebApplicationContext context; @Before public void setup() { - this.context = new StaticWebApplicationContext(); + StaticWebApplicationContext context = new StaticWebApplicationContext(); this.handlerMapping = new TestHandlerMapping(); - this.handlerMapping.setApplicationContext(this.context); + this.handlerMapping.setApplicationContext(context); this.request = new MockHttpServletRequest(); this.request.setRemoteHost("domain1.com"); - this.response = new MockHttpServletResponse(); } @Test @@ -67,6 +69,7 @@ public class CorsAbstractHandlerMappingTests { this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); HandlerExecutionChain chain = handlerMapping.getHandler(this.request); + assertNotNull(chain); assertTrue(chain.getHandler() instanceof SimpleHandler); } @@ -77,7 +80,9 @@ public class CorsAbstractHandlerMappingTests { this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); HandlerExecutionChain chain = handlerMapping.getHandler(this.request); - assertTrue(chain.getHandler() instanceof SimpleHandler); + assertNotNull(chain); + assertNotNull(chain.getHandler()); + assertTrue(chain.getHandler().getClass().getSimpleName().equals("PreFlightHandler")); } @Test @@ -87,6 +92,7 @@ public class CorsAbstractHandlerMappingTests { this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); HandlerExecutionChain chain = handlerMapping.getHandler(this.request); + assertNotNull(chain); assertTrue(chain.getHandler() instanceof CorsAwareHandler); CorsConfiguration config = getCorsConfiguration(chain, false); assertNotNull(config); @@ -100,6 +106,7 @@ public class CorsAbstractHandlerMappingTests { this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); HandlerExecutionChain chain = handlerMapping.getHandler(this.request); + assertNotNull(chain); assertNotNull(chain.getHandler()); assertTrue(chain.getHandler().getClass().getSimpleName().equals("PreFlightHandler")); CorsConfiguration config = getCorsConfiguration(chain, true); @@ -111,12 +118,13 @@ public class CorsAbstractHandlerMappingTests { public void actualRequestWithMappedCorsConfiguration() throws Exception { CorsConfiguration config = new CorsConfiguration(); config.addAllowedOrigin("*"); - this.handlerMapping.registerCorsConfiguration("/foo", config); + this.handlerMapping.setCorsConfiguration(Collections.singletonMap("/foo", config)); this.request.setMethod(RequestMethod.GET.name()); this.request.setRequestURI("/foo"); this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); HandlerExecutionChain chain = handlerMapping.getHandler(this.request); + assertNotNull(chain); assertTrue(chain.getHandler() instanceof SimpleHandler); config = getCorsConfiguration(chain, false); assertNotNull(config); @@ -127,12 +135,13 @@ public class CorsAbstractHandlerMappingTests { public void preflightRequestWithMappedCorsConfiguration() throws Exception { CorsConfiguration config = new CorsConfiguration(); config.addAllowedOrigin("*"); - this.handlerMapping.registerCorsConfiguration("/foo", config); + this.handlerMapping.setCorsConfiguration(Collections.singletonMap("/foo", config)); this.request.setMethod(RequestMethod.OPTIONS.name()); this.request.setRequestURI("/foo"); this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); HandlerExecutionChain chain = handlerMapping.getHandler(this.request); + assertNotNull(chain); assertNotNull(chain.getHandler()); assertTrue(chain.getHandler().getClass().getSimpleName().equals("PreFlightHandler")); config = getCorsConfiguration(chain, true);