From e3060981551013c57486b72cfd6637d9b8a90292 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 21 Apr 2015 15:20:38 -0400 Subject: [PATCH] Encapsulate CORS checking within CorsConfiguration CorsConfiguration now provides methods to check and determine the allowed origin, method, and headers according to its own configuration. This simplifies significantly the work that needs to be done from DefaultCorsProcessor. However an alternative CorsProcessor can still access the raw CorsConfiguration and perform its own checks. Issue: SPR-12885 --- .../server/ServletServerHttpResponse.java | 7 +- .../web/cors/CorsConfiguration.java | 258 +++++++++++------- .../web/cors/DefaultCorsProcessor.java | 252 +++++++---------- .../web/cors/CorsConfigurationTests.java | 106 +++++++ .../web/cors/DefaultCorsProcessorTests.java | 2 +- .../CorsAbstractHandlerMappingTests.java | 1 - .../method/annotation/CrossOriginTests.java | 19 +- 7 files changed, 371 insertions(+), 274 deletions(-) create mode 100644 spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java index 180ef8576d..a54112db9b 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java @@ -49,6 +49,8 @@ public class ServletServerHttpResponse implements ServerHttpResponse { private boolean headersWritten = false; + private boolean bodyUsed = false; + /** * Construct a new instance of the ServletServerHttpResponse based on the given {@link HttpServletResponse}. @@ -80,6 +82,7 @@ public class ServletServerHttpResponse implements ServerHttpResponse { @Override public OutputStream getBody() throws IOException { + this.bodyUsed = true; writeHeaders(); return this.servletResponse.getOutputStream(); } @@ -87,7 +90,9 @@ public class ServletServerHttpResponse implements ServerHttpResponse { @Override public void flush() throws IOException { writeHeaders(); - this.servletResponse.flushBuffer(); + if (this.bodyUsed) { + this.servletResponse.flushBuffer(); + } } @Override diff --git a/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java b/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java index d4fabd101f..c01b030a81 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java +++ b/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java @@ -21,9 +21,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import org.springframework.http.HttpMethod; + /** - * Represents the CORS configuration that stores various properties used to check if a - * CORS request is allowed and to generate CORS response headers. + * A container for CORS configuration also providing methods to check actual or + * or requested origin, HTTP method, and headers. * * @author Sebastien Deleuze * @author Rossen Stoyanchev @@ -45,130 +47,79 @@ public class CorsConfiguration { private Long maxAge; + /** + * Default constructor. + */ public CorsConfiguration() { } - public CorsConfiguration(CorsConfiguration config) { - if (config.allowedOrigins != null) { - this.allowedOrigins = new ArrayList(config.allowedOrigins); - } - if (config.allowCredentials != null) { - this.allowCredentials = config.allowCredentials; - } - if (config.exposedHeaders != null) { - this.exposedHeaders = new ArrayList(config.exposedHeaders); - } - if (config.allowedMethods != null) { - this.allowedMethods = new ArrayList(config.allowedMethods); - } - if (config.allowedHeaders != null) { - this.allowedHeaders = new ArrayList(config.allowedHeaders); - } - if (config.maxAge != null) { - this.maxAge = config.maxAge; - } - } - - public CorsConfiguration combine(CorsConfiguration other) { - CorsConfiguration config = new CorsConfiguration(this); - - if (other.getAllowedOrigins() != null) { - config.setAllowedOrigins(other.getAllowedOrigins()); - } - if (other.getAllowedMethods() != null) { - config.setAllowedMethods(other.getAllowedMethods()); - } - if (other.getAllowedHeaders() != null) { - config.setAllowedHeaders(other.getAllowedHeaders()); - } - if (other.getExposedHeaders() != null) { - config.setExposedHeaders(other.getExposedHeaders()); - } - if (other.getMaxAge() != null) { - config.setMaxAge(other.getMaxAge()); - } - if (other.isAllowCredentials() != null) { - config.setAllowCredentials(other.isAllowCredentials()); - } - return config; + /** + * Configure origins to allow, e.g. "http://domain1.com". The special value + * "*" allows all domains. + *

By default this is not set. + */ + public void setAllowedOrigins(List origins) { + this.allowedOrigins = origins; } /** - * @see #setAllowedOrigins(java.util.List) + * Add an origin to allow. */ - public List getAllowedOrigins() { - if (this.allowedOrigins != null) { - return this.allowedOrigins.contains("*") ? Arrays.asList("*") : Collections.unmodifiableList(this.allowedOrigins); - } - return null; - } - - /** - * Set allowed allowedOrigins that will define Access-Control-Allow-Origin response - * header values (mandatory). For example "http://domain1.com", "http://domain2.com" ... - * "*" means that all domains are allowed. - */ - public void setAllowedOrigins(List allowedOrigins) { - this.allowedOrigins = allowedOrigins; - } - - /** - * @see #setAllowedOrigins(java.util.List) - */ - public void addAllowedOrigin(String allowedOrigin) { + public void addAllowedOrigin(String origin) { if (this.allowedOrigins == null) { this.allowedOrigins = new ArrayList(); } - this.allowedOrigins.add(allowedOrigin); + this.allowedOrigins.add(origin); } /** - * @see #setAllowedMethods(java.util.List) + * Return the configured origins to allow, possibly {@code null}. */ - public List getAllowedMethods() { - return this.allowedMethods == null ? null : Collections.unmodifiableList(this.allowedMethods); + public List getAllowedOrigins() { + return this.allowedOrigins; } /** - * Set allow methods that will define Access-Control-Allow-Methods response header - * values. For example "GET", "POST", "PUT" ... "*" means that all methods requested - * by the client are allowed. If not set, allowed method is set to "GET". - * + * Configure HTTP methods to allow, e.g. "GET", "POST", "PUT". The special + * value "*" allows all method. When not set only "GET is allowed. + *

By default this is not set. */ - public void setAllowedMethods(List allowedMethods) { - this.allowedMethods = allowedMethods; + public void setAllowedMethods(List methods) { + this.allowedMethods = methods; } /** - * @see #setAllowedMethods(java.util.List) + * Add an HTTP method to allow. */ - public void addAllowedMethod(String allowedMethod) { + public void addAllowedMethod(String method) { if (this.allowedMethods == null) { this.allowedMethods = new ArrayList(); } - this.allowedMethods.add(allowedMethod); + this.allowedMethods.add(method); } /** - * @see #setAllowedHeaders(java.util.List) + * Return the allowed HTTP methods, possibly {@code null} in which case only + * HTTP GET is allowed. */ - public List getAllowedHeaders() { - return this.allowedHeaders == null ? null : Collections.unmodifiableList(this.allowedHeaders); + public List getAllowedMethods() { + return this.allowedMethods; } /** - * Set a list of request headers that will define Access-Control-Allow-Methods response - * header values. If a header field name is one of the following, it is not required - * to be listed: Cache-Control, Content-Language, Expires, Last-Modified, Pragma. - * "*" means that all headers asked by the client will be allowed. + * Configure the list of headers that a pre-flight request can list as allowed + * for use during an actual request. The special value of "*" allows actual + * requests to send any header. A header name is not required to be listed if + * it is one of: Cache-Control, Content-Language, Expires, Last-Modified, Pragma. + *

By default this is not set. */ public void setAllowedHeaders(List allowedHeaders) { this.allowedHeaders = allowedHeaders; } /** - * @see #setAllowedHeaders(java.util.List) + * Add one actual request header to allow. */ public void addAllowedHeader(String allowedHeader) { if (this.allowedHeaders == null) { @@ -178,23 +129,24 @@ public class CorsConfiguration { } /** - * @see #setExposedHeaders(java.util.List) + * Return the allowed actual request headers, possibly {@code null}. */ - public List getExposedHeaders() { - return this.exposedHeaders == null ? null : Collections.unmodifiableList(this.exposedHeaders); + public List getAllowedHeaders() { + return this.allowedHeaders; } /** - * Set a list of response headers other than simple headers that the resource might use - * and can be exposed. Simple response headers are: Cache-Control, Content-Language, - * Content-Type, Expires, Last-Modified, Pragma. + * Configure the list of response headers other than simple headers (i.e. + * Cache-Control, Content-Language, Content-Type, Expires, Last-Modified, + * Pragma) that an actual response might have and can be exposed. + *

By default this is not set. */ public void setExposedHeaders(List exposedHeaders) { this.exposedHeaders = exposedHeaders; } /** - * @see #setExposedHeaders(java.util.List) + * Add a single response header to expose. */ public void addExposedHeader(String exposedHeader) { if (this.exposedHeaders == null) { @@ -204,33 +156,131 @@ public class CorsConfiguration { } /** - * @see #setAllowCredentials(Boolean) + * Return the configured response headers to expose, possibly {@code null}. */ - public Boolean isAllowCredentials() { - return this.allowCredentials; + public List getExposedHeaders() { + return this.exposedHeaders; } /** - * Indicates whether the resource supports user credentials. - * Set the value of Access-Control-Allow-Credentials response header. + * Whether user credentials are supported. + *

By default this is not set (i.e. user credentials not supported). */ public void setAllowCredentials(Boolean allowCredentials) { this.allowCredentials = allowCredentials; } /** - * @see #setMaxAge(Long) + * Return the configured allowCredentials, possibly {@code null}. */ - public Long getMaxAge() { - return maxAge; + public Boolean getAllowCredentials() { + return this.allowCredentials; } /** - * Indicates how long (seconds) the results of a preflight request can be cached - * in a preflight result cache. + * Configure how long, in seconds, the response from a pre-flight request + * can be cached by clients. + *

By default this is not set. */ public void setMaxAge(Long maxAge) { this.maxAge = maxAge; } + /** + * Return the configure maxAge value, possibly {@code null}. + */ + public Long getMaxAge() { + return maxAge; + } + + + /** + * Check the origin of the request against the configured allowed origins. + * @param requestOrigin the origin to check. + * @return the origin to use for the response, possibly {@code null} which + * means the request origin is not allowed. + */ + public String checkOrigin(String requestOrigin) { + if (requestOrigin == null) { + return null; + } + List allowedOrigins = this.allowedOrigins == null ? + new ArrayList() : this.allowedOrigins; + if (allowedOrigins.contains("*")) { + if (this.allowCredentials == null || !this.allowCredentials) { + return "*"; + } else { + return requestOrigin; + } + } + for (String allowedOrigin : allowedOrigins) { + if (requestOrigin.equalsIgnoreCase(allowedOrigin)) { + return requestOrigin; + } + } + return null; + } + + /** + * Check the request HTTP method (or the method from the + * Access-Control-Request-Method header on a pre-flight request) against the + * configured allowed methods. + * @param requestMethod the HTTP method to check. + * @return the list of HTTP methods to list in the response of a pre-flight + * request, or {@code null} if the requestMethod is not allowed. + */ + public List checkHttpMethod(HttpMethod requestMethod) { + if (requestMethod == null) { + return null; + } + List allowedMethods = this.allowedMethods == null ? + new ArrayList() : this.allowedMethods; + if (allowedMethods.contains("*")) { + return Arrays.asList(requestMethod); + } + if (allowedMethods.isEmpty()) { + allowedMethods.add(HttpMethod.GET.name()); + } + List result = new ArrayList(allowedMethods.size()); + boolean allowed = false; + for (String method : allowedMethods) { + if (method.equals(requestMethod.name())) { + allowed = true; + } + result.add(HttpMethod.valueOf(method)); + } + return allowed ? result : null; + } + + /** + * Check the request headers (or the headers listed in the + * Access-Control-Request-Headers of a pre-flight request) against the + * configured allowed headers. + * @param requestHeaders the headers to check. + * @return the list of allowed headers to list in the response of a pre-flight + * request, or {@code null} if a requestHeader is not allowed. + */ + public List checkHeaders(List requestHeaders) { + if (requestHeaders == null) { + return null; + } + if (requestHeaders.isEmpty()) { + return Collections.emptyList(); + } + List allowedHeaders = this.allowedHeaders == null ? + new ArrayList() : this.allowedHeaders; + boolean allowAnyHeader = allowedHeaders.contains("*"); + List result = new ArrayList(); + for (String requestHeader : requestHeaders) { + requestHeader = requestHeader.trim(); + for (String allowedHeader : allowedHeaders) { + if (allowAnyHeader || requestHeader.equalsIgnoreCase(allowedHeader)) { + result.add(requestHeader); + break; + } + } + } + return result.isEmpty() ? null : result; + } + } diff --git a/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java index e441604b1b..6a141897ba 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java +++ b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java @@ -19,7 +19,6 @@ package org.springframework.web.cors; import java.io.IOException; import java.nio.charset.Charset; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import javax.servlet.http.HttpServletRequest; @@ -36,12 +35,14 @@ import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; /** * Default implementation of {@link CorsProcessor}, as defined by the * CORS W3C recommandation. * * @author Sebastien Deleuze + * @author Rossen Stoyanhcev * @since 4.2 */ public class DefaultCorsProcessor implements CorsProcessor { @@ -56,23 +57,19 @@ public class DefaultCorsProcessor implements CorsProcessor { public boolean processPreFlightRequest(CorsConfiguration config, HttpServletRequest request, HttpServletResponse response) throws IOException { - ServerHttpRequest serverRequest = new ServletServerHttpRequest(request); + Assert.isTrue(CorsUtils.isPreFlightRequest(request)); + ServerHttpResponse serverResponse = new ServletServerHttpResponse(response); - boolean isPreFlight = CorsUtils.isPreFlightRequest(request); - Assert.isTrue(isPreFlight); - if (skip(serverResponse)) { + if (responseHasCors(serverResponse)) { return true; } - if (check(serverRequest, serverResponse, config, isPreFlight)) { - setAllowOrigin(serverRequest, serverResponse, config.getAllowedOrigins(), config.isAllowCredentials()); - setAllowCredentials(serverResponse, config.isAllowCredentials()); - setAllowMethods(serverRequest, serverResponse, config.getAllowedMethods()); - setAllowHeadersHeader(serverRequest, serverResponse, config.getAllowedHeaders()); - setMaxAgeHeader(serverResponse, config.getMaxAge()); - serverResponse.close(); + ServerHttpRequest serverRequest = new ServletServerHttpRequest(request); + if (handleInternal(serverRequest, serverResponse, config, true)) { + serverResponse.flush(); return true; } + return false; } @@ -80,185 +77,124 @@ public class DefaultCorsProcessor implements CorsProcessor { public boolean processActualRequest(CorsConfiguration config, HttpServletRequest request, HttpServletResponse response) throws IOException { - ServerHttpRequest serverRequest = new ServletServerHttpRequest(request); - ServerHttpResponse serverResponse = new ServletServerHttpResponse(response); - boolean isPreFlight = CorsUtils.isPreFlightRequest(request); - Assert.isTrue(CorsUtils.isCorsRequest(request) && !isPreFlight); - if (skip(serverResponse)) { + Assert.isTrue(CorsUtils.isCorsRequest(request) && !CorsUtils.isPreFlightRequest(request)); + + ServletServerHttpResponse serverResponse = new ServletServerHttpResponse(response); + if (responseHasCors(serverResponse)) { return true; } - if (check(serverRequest, serverResponse, config, isPreFlight)) { - setAllowOrigin(serverRequest, serverResponse, config.getAllowedOrigins(), config.isAllowCredentials()); - setAllowCredentials(serverResponse, config.isAllowCredentials()); - setExposeHeadersHeader(serverResponse, config.getExposedHeaders()); - serverResponse.close(); + ServletServerHttpRequest serverRequest = new ServletServerHttpRequest(request); + if (handleInternal(serverRequest, serverResponse, config, false)) { + serverResponse.flush(); return true; } + return false; } - private boolean skip(ServerHttpResponse response) { - if (hasAllowOriginHeader(response)) { - logger.debug("Skip adding CORS headers, response already contains \"Access-Control-Allow-Origin\""); - return true; - } - return false; - } - - private boolean check(ServerHttpRequest request, ServerHttpResponse response, - CorsConfiguration config, boolean isPreFlight) throws IOException { - - if (!checkOrigin(request, config.getAllowedOrigins()) || - !checkMethod(request, config.getAllowedMethods(), isPreFlight) || - !checkHeaders(request, config.getAllowedHeaders(), isPreFlight)) { - response.setStatusCode(HttpStatus.FORBIDDEN); - response.getBody().write("Invalid CORS request".getBytes(UTF8_CHARSET)); - return false; - } - return true; - } - - private boolean hasAllowOriginHeader(ServerHttpResponse response) { - boolean hasCorsResponseHeaders = false; + private boolean responseHasCors(ServerHttpResponse response) { + boolean hasAllowOrigin = false; try { - // Perhaps a CORS Filter has already added this? - hasCorsResponseHeaders = response.getHeaders().getAccessControlAllowOrigin() != null; + hasAllowOrigin = (response.getHeaders().getAccessControlAllowOrigin() != null); } catch (NullPointerException npe) { - // See SPR-11919 and https://issues.jboss.org/browse/WFLY-3474 + // SPR-11919 and https://issues.jboss.org/browse/WFLY-3474 } - return hasCorsResponseHeaders; + if (hasAllowOrigin) { + logger.debug("Skip adding CORS headers, response already contains \"Access-Control-Allow-Origin\""); + } + return hasAllowOrigin; } - private boolean checkOrigin(ServerHttpRequest request, List allowedOrigins) { - String originHeader = request.getHeaders().getOrigin(); - if (originHeader == null || allowedOrigins == null) { + protected boolean handleInternal(ServerHttpRequest request, ServerHttpResponse response, + CorsConfiguration config, boolean isPreFlight) throws IOException { + + String requestOrigin = request.getHeaders().getOrigin(); + String allowOrigin = checkOrigin(config, requestOrigin); + + HttpMethod requestMethod = getMethodToUse(request, isPreFlight); + List allowMethods = checkMethods(config, requestMethod); + + List requestHeaders = getHeadersToUse(request, isPreFlight); + List allowHeaders = checkHeaders(config, requestHeaders); + + if (allowOrigin == null || allowMethods == null || (isPreFlight && allowHeaders == null)) { + handleInvalidCorsRequest(response); return false; } - if (allowedOrigins.contains("*")) { - return true; - } - for (String allowedOrigin : allowedOrigins) { - if (originHeader.equalsIgnoreCase(allowedOrigin)) { - return true; - } - } - return false; - } - private boolean checkMethod(ServerHttpRequest request, List allowedMethods, boolean isPreFlight) { - HttpMethod requestMethod = isPreFlight ? - request.getHeaders().getAccessControlRequestMethod() : - request.getMethod(); - if (allowedMethods == null) { - allowedMethods = Arrays.asList(HttpMethod.GET.name()); - } - if (allowedMethods.contains("*")) { - return true; - } - for (String allowedMethod : allowedMethods) { - if (allowedMethod.equalsIgnoreCase(requestMethod.name())) { - return true; - } - } - return false; - } + HttpHeaders responseHeaders = response.getHeaders(); + responseHeaders.setAccessControlAllowOrigin(allowOrigin); + responseHeaders.add(HttpHeaders.VARY, HttpHeaders.ORIGIN); - private boolean checkHeaders(ServerHttpRequest request, List allowedHeaders, boolean isPreFlight) { - List requestHeaders = isPreFlight ? request.getHeaders().getAccessControlRequestHeaders() : - new ArrayList(request.getHeaders().keySet()); - if ((allowedHeaders != null) && allowedHeaders.contains("*")) { - return true; + if (isPreFlight) { + responseHeaders.setAccessControlAllowMethods(allowMethods); } - for (String requestHeader : requestHeaders) { - if (!HttpHeaders.ORIGIN.equals(requestHeader)) { - requestHeader = requestHeader.trim(); - boolean found = false; - if (allowedHeaders != null) { - for (String header : allowedHeaders) { - if (requestHeader.equalsIgnoreCase(header)) { - found = true; - break; - } - } - } - if (!found) { - return false; - } - } + + if (isPreFlight && !allowHeaders.isEmpty()) { + responseHeaders.setAccessControlAllowHeaders(allowHeaders); } + + if (!CollectionUtils.isEmpty(config.getExposedHeaders())) { + responseHeaders.setAccessControlExposeHeaders(config.getExposedHeaders()); + } + + if (Boolean.TRUE.equals(config.getAllowCredentials())) { + responseHeaders.setAccessControlAllowCredentials(true); + } + + if (isPreFlight && config.getMaxAge() != null) { + responseHeaders.setAccessControlMaxAge(config.getMaxAge()); + } + return true; } - private void setAllowOrigin(ServerHttpRequest request, ServerHttpResponse response, - List allowedOrigins, Boolean allowCredentials) { - - String origin = request.getHeaders().getOrigin(); - if (allowedOrigins.contains("*") && (allowCredentials == null || !allowCredentials)) { - response.getHeaders().setAccessControlAllowOrigin("*"); - return; - } - response.getHeaders().setAccessControlAllowOrigin(origin); - response.getHeaders().add(HttpHeaders.VARY, HttpHeaders.ORIGIN); + /** + * Check the origin and determine the origin for the response. The default + * implementation simply delegates to + * {@link org.springframework.web.cors.CorsConfiguration#checkOrigin(String)} + */ + protected String checkOrigin(CorsConfiguration config, String requestOrigin) { + return config.checkOrigin(requestOrigin); } - private void setAllowMethods(ServerHttpRequest request, ServerHttpResponse response, - List allowedMethods) { - - if (allowedMethods == null) { - allowedMethods = Arrays.asList(HttpMethod.GET.name()); - } - if (allowedMethods.contains("*")) { - HttpMethod method = request.getHeaders().getAccessControlRequestMethod(); - response.getHeaders().setAccessControlAllowMethods(Arrays.asList(method)); - } - else { - List methods = new ArrayList(); - for (String method : allowedMethods) { - methods.add(HttpMethod.valueOf(method)); - } - response.getHeaders().setAccessControlAllowMethods(methods); - } + /** + * Check the HTTP method and determine the methods for the response of a + * pre-flight request. The default implementation simply delegates to + * {@link org.springframework.web.cors.CorsConfiguration#checkOrigin(String)} + */ + protected List checkMethods(CorsConfiguration config, HttpMethod requestMethod) { + return config.checkHttpMethod(requestMethod); } - private void setAllowHeadersHeader(ServerHttpRequest request, ServerHttpResponse response, - List allowedHeaders) { - if ((allowedHeaders != null) && !allowedHeaders.isEmpty()) { - List requestHeaders = request.getHeaders().getAccessControlRequestHeaders(); - boolean matchAll = allowedHeaders.contains("*"); - List matchingHeaders = new ArrayList(); - for (String requestHeader : requestHeaders) { - for (String header : allowedHeaders) { - requestHeader = requestHeader.trim(); - if (matchAll || requestHeader.equalsIgnoreCase(header)) { - matchingHeaders.add(requestHeader); - break; - } - } - } - if (!matchingHeaders.isEmpty()) { - response.getHeaders().setAccessControlAllowHeaders(matchingHeaders); - } - } + private HttpMethod getMethodToUse(ServerHttpRequest request, boolean isPreFlight) { + return (isPreFlight ? request.getHeaders().getAccessControlRequestMethod() : request.getMethod()); } - private void setExposeHeadersHeader(ServerHttpResponse response, List exposedHeaders) { - if ((exposedHeaders != null) && !exposedHeaders.isEmpty()) { - response.getHeaders().setAccessControlExposeHeaders(exposedHeaders); - } + /** + * Check the headers and determine the headers for the response of a + * pre-flight request. The default implementation simply delegates to + * {@link org.springframework.web.cors.CorsConfiguration#checkOrigin(String)} + */ + protected List checkHeaders(CorsConfiguration config, List requestHeaders) { + return config.checkHeaders(requestHeaders); } - private void setAllowCredentials(ServerHttpResponse response, Boolean allowCredentials) { - if ((allowCredentials != null) && allowCredentials) { - response.getHeaders().setAccessControlAllowCredentials(allowCredentials); - } + private List getHeadersToUse(ServerHttpRequest request, boolean isPreFlight) { + HttpHeaders headers = request.getHeaders(); + return (isPreFlight ? headers.getAccessControlRequestHeaders() : new ArrayList(headers.keySet())); } - private void setMaxAgeHeader(ServerHttpResponse response, Long maxAge) { - if (maxAge != null) { - response.getHeaders().setAccessControlMaxAge(maxAge); - } + /** + * Invoked when one of the CORS checks failed. + * The default implementation sets the response status to 403 and writes + * "Invalid CORS request" to the response. + */ + protected void handleInvalidCorsRequest(ServerHttpResponse response) throws IOException { + response.setStatusCode(HttpStatus.FORBIDDEN); + response.getBody().write("Invalid CORS request".getBytes(UTF8_CHARSET)); } } diff --git a/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java b/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java new file mode 100644 index 0000000000..081290c9e0 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpMethod; + +/** + * Test case for {@link CorsConfiguration}. + * + * @author Sebastien Deleuze + */ +public class CorsConfigurationTests { + + private CorsConfiguration config; + + @Before + public void setup() { + config = new CorsConfiguration(); + } + + @Test + public void checkOriginAllowed() { + config.setAllowedOrigins(Arrays.asList("*")); + assertEquals("*", config.checkOrigin("http://domain.com")); + config.setAllowCredentials(true); + assertEquals("http://domain.com", config.checkOrigin("http://domain.com")); + config.setAllowedOrigins(Arrays.asList("http://domain.com")); + assertEquals("http://domain.com", config.checkOrigin("http://domain.com")); + config.setAllowCredentials(false); + assertEquals("http://domain.com", config.checkOrigin("http://domain.com")); + } + + @Test + public void checkOriginNotAllowed() { + assertNull(config.checkOrigin(null)); + assertNull(config.checkOrigin("http://domain.com")); + config.addAllowedOrigin("*"); + assertNull(config.checkOrigin(null)); + config.setAllowedOrigins(Arrays.asList("http://domain1.com")); + assertNull(config.checkOrigin("http://domain2.com")); + config.setAllowedOrigins(new ArrayList<>()); + assertNull(config.checkOrigin("http://domain.com")); + } + + @Test + public void checkMethodAllowed() { + assertEquals(Arrays.asList(HttpMethod.GET), config.checkHttpMethod(HttpMethod.GET)); + config.addAllowedMethod("GET"); + assertEquals(Arrays.asList(HttpMethod.GET), config.checkHttpMethod(HttpMethod.GET)); + config.addAllowedMethod("POST"); + assertEquals(Arrays.asList(HttpMethod.GET, HttpMethod.POST), config.checkHttpMethod(HttpMethod.GET)); + assertEquals(Arrays.asList(HttpMethod.GET, HttpMethod.POST), config.checkHttpMethod(HttpMethod.POST)); + } + + @Test + public void checkMethodNotAllowed() { + assertNull(config.checkHttpMethod(null)); + assertNull(config.checkHttpMethod(HttpMethod.DELETE)); + config.setAllowedMethods(new ArrayList<>()); + assertNull(config.checkHttpMethod(HttpMethod.HEAD)); + } + + @Test + public void checkHeadersAllowed() { + assertEquals(Collections.emptyList(), config.checkHeaders(Collections.emptyList())); + config.addAllowedHeader("header1"); + config.addAllowedHeader("header2"); + assertEquals(Arrays.asList("header1"), config.checkHeaders(Arrays.asList("header1"))); + assertEquals(Arrays.asList("header1", "header2"), config.checkHeaders(Arrays.asList("header1", "header2"))); + assertEquals(Arrays.asList("header1", "header2"), config.checkHeaders(Arrays.asList("header1", "header2", "header3"))); + } + + @Test + public void checkHeadersNotAllowed() { + assertNull(config.checkHeaders(null)); + assertNull(config.checkHeaders(Arrays.asList("header1"))); + config.addAllowedHeader("header2"); + assertNull(config.checkHeaders(Arrays.asList("header1"))); + config.setAllowedHeaders(new ArrayList<>()); + assertNull(config.checkHeaders(Arrays.asList("header1"))); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java b/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java index e0ee8efedf..bdcc188035 100644 --- a/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java +++ b/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java @@ -243,7 +243,7 @@ public class DefaultCorsProcessorTests { } @Test - public void preflightRequestCrendentialsWithOriginWildcard() throws Exception { + public void preflightRequestCredentialsWithOriginWildcard() throws Exception { this.request.setMethod(HttpMethod.OPTIONS.name()); this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1"); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java index 2c28728495..615ae81e9d 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java @@ -35,7 +35,6 @@ import org.springframework.mock.web.test.MockHttpServletResponse; import org.springframework.web.HttpRequestHandler; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.context.support.StaticWebApplicationContext; -import org.springframework.web.cors.CorsUtils; import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.support.WebContentGenerator; diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java index a737bc410d..55b10a7504 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java @@ -25,6 +25,7 @@ import org.junit.Test; import org.springframework.beans.DirectFieldAccessor; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.http.HttpHeaders; +import org.springframework.util.CollectionUtils; import org.springframework.web.context.support.StaticWebApplicationContext; import org.springframework.web.cors.CorsConfiguration; import org.springframework.mock.web.test.MockHttpServletRequest; @@ -101,9 +102,9 @@ public class CrossOriginTests { assertNotNull(config); assertArrayEquals(new String[]{"GET"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray()); - assertTrue(config.isAllowCredentials()); + assertTrue(config.getAllowCredentials()); assertArrayEquals(new String[]{"*"}, config.getAllowedHeaders().toArray()); - assertNull(config.getExposedHeaders()); + assertTrue(CollectionUtils.isEmpty(config.getExposedHeaders())); assertEquals(new Long(1800), config.getMaxAge()); } @@ -119,7 +120,7 @@ public class CrossOriginTests { assertArrayEquals(new String[]{"header1", "header2"}, config.getAllowedHeaders().toArray()); assertArrayEquals(new String[]{"header3", "header4"}, config.getExposedHeaders().toArray()); assertEquals(new Long(123), config.getMaxAge()); - assertEquals(false, config.isAllowCredentials()); + assertEquals(false, config.getAllowCredentials()); } @Test @@ -144,9 +145,9 @@ public class CrossOriginTests { assertNotNull(config); assertArrayEquals(new String[]{"GET"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray()); - assertTrue(config.isAllowCredentials()); + assertTrue(config.getAllowCredentials()); assertArrayEquals(new String[]{"*"}, config.getAllowedHeaders().toArray()); - assertNull(config.getExposedHeaders()); + assertTrue(CollectionUtils.isEmpty(config.getExposedHeaders())); assertEquals(new Long(1800), config.getMaxAge()); } @@ -163,8 +164,8 @@ public class CrossOriginTests { assertArrayEquals(new String[]{"*"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray()); assertArrayEquals(new String[]{"*"}, config.getAllowedHeaders().toArray()); - assertTrue(config.isAllowCredentials()); - assertNull(config.getExposedHeaders()); + assertTrue(config.getAllowCredentials()); + assertTrue(CollectionUtils.isEmpty(config.getExposedHeaders())); assertNull(config.getMaxAge()); } @@ -180,8 +181,8 @@ public class CrossOriginTests { assertArrayEquals(new String[]{"*"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray()); assertArrayEquals(new String[]{"*"}, config.getAllowedHeaders().toArray()); - assertTrue(config.isAllowCredentials()); - assertNull(config.getExposedHeaders()); + assertTrue(config.getAllowCredentials()); + assertTrue(CollectionUtils.isEmpty(config.getExposedHeaders())); assertNull(config.getMaxAge()); }