diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java index 180ef8576d..a54112db9b 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java @@ -49,6 +49,8 @@ public class ServletServerHttpResponse implements ServerHttpResponse { private boolean headersWritten = false; + private boolean bodyUsed = false; + /** * Construct a new instance of the ServletServerHttpResponse based on the given {@link HttpServletResponse}. @@ -80,6 +82,7 @@ public class ServletServerHttpResponse implements ServerHttpResponse { @Override public OutputStream getBody() throws IOException { + this.bodyUsed = true; writeHeaders(); return this.servletResponse.getOutputStream(); } @@ -87,7 +90,9 @@ public class ServletServerHttpResponse implements ServerHttpResponse { @Override public void flush() throws IOException { writeHeaders(); - this.servletResponse.flushBuffer(); + if (this.bodyUsed) { + this.servletResponse.flushBuffer(); + } } @Override diff --git a/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java b/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java index d4fabd101f..c01b030a81 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java +++ b/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java @@ -21,9 +21,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import org.springframework.http.HttpMethod; + /** - * Represents the CORS configuration that stores various properties used to check if a - * CORS request is allowed and to generate CORS response headers. + * A container for CORS configuration also providing methods to check actual or + * or requested origin, HTTP method, and headers. * * @author Sebastien Deleuze * @author Rossen Stoyanchev @@ -45,130 +47,79 @@ public class CorsConfiguration { private Long maxAge; + /** + * Default constructor. + */ public CorsConfiguration() { } - public CorsConfiguration(CorsConfiguration config) { - if (config.allowedOrigins != null) { - this.allowedOrigins = new ArrayList(config.allowedOrigins); - } - if (config.allowCredentials != null) { - this.allowCredentials = config.allowCredentials; - } - if (config.exposedHeaders != null) { - this.exposedHeaders = new ArrayList(config.exposedHeaders); - } - if (config.allowedMethods != null) { - this.allowedMethods = new ArrayList(config.allowedMethods); - } - if (config.allowedHeaders != null) { - this.allowedHeaders = new ArrayList(config.allowedHeaders); - } - if (config.maxAge != null) { - this.maxAge = config.maxAge; - } - } - - public CorsConfiguration combine(CorsConfiguration other) { - CorsConfiguration config = new CorsConfiguration(this); - - if (other.getAllowedOrigins() != null) { - config.setAllowedOrigins(other.getAllowedOrigins()); - } - if (other.getAllowedMethods() != null) { - config.setAllowedMethods(other.getAllowedMethods()); - } - if (other.getAllowedHeaders() != null) { - config.setAllowedHeaders(other.getAllowedHeaders()); - } - if (other.getExposedHeaders() != null) { - config.setExposedHeaders(other.getExposedHeaders()); - } - if (other.getMaxAge() != null) { - config.setMaxAge(other.getMaxAge()); - } - if (other.isAllowCredentials() != null) { - config.setAllowCredentials(other.isAllowCredentials()); - } - return config; + /** + * Configure origins to allow, e.g. "http://domain1.com". The special value + * "*" allows all domains. + *

By default this is not set. + */ + public void setAllowedOrigins(List origins) { + this.allowedOrigins = origins; } /** - * @see #setAllowedOrigins(java.util.List) + * Add an origin to allow. */ - public List getAllowedOrigins() { - if (this.allowedOrigins != null) { - return this.allowedOrigins.contains("*") ? Arrays.asList("*") : Collections.unmodifiableList(this.allowedOrigins); - } - return null; - } - - /** - * Set allowed allowedOrigins that will define Access-Control-Allow-Origin response - * header values (mandatory). For example "http://domain1.com", "http://domain2.com" ... - * "*" means that all domains are allowed. - */ - public void setAllowedOrigins(List allowedOrigins) { - this.allowedOrigins = allowedOrigins; - } - - /** - * @see #setAllowedOrigins(java.util.List) - */ - public void addAllowedOrigin(String allowedOrigin) { + public void addAllowedOrigin(String origin) { if (this.allowedOrigins == null) { this.allowedOrigins = new ArrayList(); } - this.allowedOrigins.add(allowedOrigin); + this.allowedOrigins.add(origin); } /** - * @see #setAllowedMethods(java.util.List) + * Return the configured origins to allow, possibly {@code null}. */ - public List getAllowedMethods() { - return this.allowedMethods == null ? null : Collections.unmodifiableList(this.allowedMethods); + public List getAllowedOrigins() { + return this.allowedOrigins; } /** - * Set allow methods that will define Access-Control-Allow-Methods response header - * values. For example "GET", "POST", "PUT" ... "*" means that all methods requested - * by the client are allowed. If not set, allowed method is set to "GET". - * + * Configure HTTP methods to allow, e.g. "GET", "POST", "PUT". The special + * value "*" allows all method. When not set only "GET is allowed. + *

By default this is not set. */ - public void setAllowedMethods(List allowedMethods) { - this.allowedMethods = allowedMethods; + public void setAllowedMethods(List methods) { + this.allowedMethods = methods; } /** - * @see #setAllowedMethods(java.util.List) + * Add an HTTP method to allow. */ - public void addAllowedMethod(String allowedMethod) { + public void addAllowedMethod(String method) { if (this.allowedMethods == null) { this.allowedMethods = new ArrayList(); } - this.allowedMethods.add(allowedMethod); + this.allowedMethods.add(method); } /** - * @see #setAllowedHeaders(java.util.List) + * Return the allowed HTTP methods, possibly {@code null} in which case only + * HTTP GET is allowed. */ - public List getAllowedHeaders() { - return this.allowedHeaders == null ? null : Collections.unmodifiableList(this.allowedHeaders); + public List getAllowedMethods() { + return this.allowedMethods; } /** - * Set a list of request headers that will define Access-Control-Allow-Methods response - * header values. If a header field name is one of the following, it is not required - * to be listed: Cache-Control, Content-Language, Expires, Last-Modified, Pragma. - * "*" means that all headers asked by the client will be allowed. + * Configure the list of headers that a pre-flight request can list as allowed + * for use during an actual request. The special value of "*" allows actual + * requests to send any header. A header name is not required to be listed if + * it is one of: Cache-Control, Content-Language, Expires, Last-Modified, Pragma. + *

By default this is not set. */ public void setAllowedHeaders(List allowedHeaders) { this.allowedHeaders = allowedHeaders; } /** - * @see #setAllowedHeaders(java.util.List) + * Add one actual request header to allow. */ public void addAllowedHeader(String allowedHeader) { if (this.allowedHeaders == null) { @@ -178,23 +129,24 @@ public class CorsConfiguration { } /** - * @see #setExposedHeaders(java.util.List) + * Return the allowed actual request headers, possibly {@code null}. */ - public List getExposedHeaders() { - return this.exposedHeaders == null ? null : Collections.unmodifiableList(this.exposedHeaders); + public List getAllowedHeaders() { + return this.allowedHeaders; } /** - * Set a list of response headers other than simple headers that the resource might use - * and can be exposed. Simple response headers are: Cache-Control, Content-Language, - * Content-Type, Expires, Last-Modified, Pragma. + * Configure the list of response headers other than simple headers (i.e. + * Cache-Control, Content-Language, Content-Type, Expires, Last-Modified, + * Pragma) that an actual response might have and can be exposed. + *

By default this is not set. */ public void setExposedHeaders(List exposedHeaders) { this.exposedHeaders = exposedHeaders; } /** - * @see #setExposedHeaders(java.util.List) + * Add a single response header to expose. */ public void addExposedHeader(String exposedHeader) { if (this.exposedHeaders == null) { @@ -204,33 +156,131 @@ public class CorsConfiguration { } /** - * @see #setAllowCredentials(Boolean) + * Return the configured response headers to expose, possibly {@code null}. */ - public Boolean isAllowCredentials() { - return this.allowCredentials; + public List getExposedHeaders() { + return this.exposedHeaders; } /** - * Indicates whether the resource supports user credentials. - * Set the value of Access-Control-Allow-Credentials response header. + * Whether user credentials are supported. + *

By default this is not set (i.e. user credentials not supported). */ public void setAllowCredentials(Boolean allowCredentials) { this.allowCredentials = allowCredentials; } /** - * @see #setMaxAge(Long) + * Return the configured allowCredentials, possibly {@code null}. */ - public Long getMaxAge() { - return maxAge; + public Boolean getAllowCredentials() { + return this.allowCredentials; } /** - * Indicates how long (seconds) the results of a preflight request can be cached - * in a preflight result cache. + * Configure how long, in seconds, the response from a pre-flight request + * can be cached by clients. + *

By default this is not set. */ public void setMaxAge(Long maxAge) { this.maxAge = maxAge; } + /** + * Return the configure maxAge value, possibly {@code null}. + */ + public Long getMaxAge() { + return maxAge; + } + + + /** + * Check the origin of the request against the configured allowed origins. + * @param requestOrigin the origin to check. + * @return the origin to use for the response, possibly {@code null} which + * means the request origin is not allowed. + */ + public String checkOrigin(String requestOrigin) { + if (requestOrigin == null) { + return null; + } + List allowedOrigins = this.allowedOrigins == null ? + new ArrayList() : this.allowedOrigins; + if (allowedOrigins.contains("*")) { + if (this.allowCredentials == null || !this.allowCredentials) { + return "*"; + } else { + return requestOrigin; + } + } + for (String allowedOrigin : allowedOrigins) { + if (requestOrigin.equalsIgnoreCase(allowedOrigin)) { + return requestOrigin; + } + } + return null; + } + + /** + * Check the request HTTP method (or the method from the + * Access-Control-Request-Method header on a pre-flight request) against the + * configured allowed methods. + * @param requestMethod the HTTP method to check. + * @return the list of HTTP methods to list in the response of a pre-flight + * request, or {@code null} if the requestMethod is not allowed. + */ + public List checkHttpMethod(HttpMethod requestMethod) { + if (requestMethod == null) { + return null; + } + List allowedMethods = this.allowedMethods == null ? + new ArrayList() : this.allowedMethods; + if (allowedMethods.contains("*")) { + return Arrays.asList(requestMethod); + } + if (allowedMethods.isEmpty()) { + allowedMethods.add(HttpMethod.GET.name()); + } + List result = new ArrayList(allowedMethods.size()); + boolean allowed = false; + for (String method : allowedMethods) { + if (method.equals(requestMethod.name())) { + allowed = true; + } + result.add(HttpMethod.valueOf(method)); + } + return allowed ? result : null; + } + + /** + * Check the request headers (or the headers listed in the + * Access-Control-Request-Headers of a pre-flight request) against the + * configured allowed headers. + * @param requestHeaders the headers to check. + * @return the list of allowed headers to list in the response of a pre-flight + * request, or {@code null} if a requestHeader is not allowed. + */ + public List checkHeaders(List requestHeaders) { + if (requestHeaders == null) { + return null; + } + if (requestHeaders.isEmpty()) { + return Collections.emptyList(); + } + List allowedHeaders = this.allowedHeaders == null ? + new ArrayList() : this.allowedHeaders; + boolean allowAnyHeader = allowedHeaders.contains("*"); + List result = new ArrayList(); + for (String requestHeader : requestHeaders) { + requestHeader = requestHeader.trim(); + for (String allowedHeader : allowedHeaders) { + if (allowAnyHeader || requestHeader.equalsIgnoreCase(allowedHeader)) { + result.add(requestHeader); + break; + } + } + } + return result.isEmpty() ? null : result; + } + } diff --git a/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java index e441604b1b..6a141897ba 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java +++ b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java @@ -19,7 +19,6 @@ package org.springframework.web.cors; import java.io.IOException; import java.nio.charset.Charset; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import javax.servlet.http.HttpServletRequest; @@ -36,12 +35,14 @@ import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; /** * Default implementation of {@link CorsProcessor}, as defined by the * CORS W3C recommandation. * * @author Sebastien Deleuze + * @author Rossen Stoyanhcev * @since 4.2 */ public class DefaultCorsProcessor implements CorsProcessor { @@ -56,23 +57,19 @@ public class DefaultCorsProcessor implements CorsProcessor { public boolean processPreFlightRequest(CorsConfiguration config, HttpServletRequest request, HttpServletResponse response) throws IOException { - ServerHttpRequest serverRequest = new ServletServerHttpRequest(request); + Assert.isTrue(CorsUtils.isPreFlightRequest(request)); + ServerHttpResponse serverResponse = new ServletServerHttpResponse(response); - boolean isPreFlight = CorsUtils.isPreFlightRequest(request); - Assert.isTrue(isPreFlight); - if (skip(serverResponse)) { + if (responseHasCors(serverResponse)) { return true; } - if (check(serverRequest, serverResponse, config, isPreFlight)) { - setAllowOrigin(serverRequest, serverResponse, config.getAllowedOrigins(), config.isAllowCredentials()); - setAllowCredentials(serverResponse, config.isAllowCredentials()); - setAllowMethods(serverRequest, serverResponse, config.getAllowedMethods()); - setAllowHeadersHeader(serverRequest, serverResponse, config.getAllowedHeaders()); - setMaxAgeHeader(serverResponse, config.getMaxAge()); - serverResponse.close(); + ServerHttpRequest serverRequest = new ServletServerHttpRequest(request); + if (handleInternal(serverRequest, serverResponse, config, true)) { + serverResponse.flush(); return true; } + return false; } @@ -80,185 +77,124 @@ public class DefaultCorsProcessor implements CorsProcessor { public boolean processActualRequest(CorsConfiguration config, HttpServletRequest request, HttpServletResponse response) throws IOException { - ServerHttpRequest serverRequest = new ServletServerHttpRequest(request); - ServerHttpResponse serverResponse = new ServletServerHttpResponse(response); - boolean isPreFlight = CorsUtils.isPreFlightRequest(request); - Assert.isTrue(CorsUtils.isCorsRequest(request) && !isPreFlight); - if (skip(serverResponse)) { + Assert.isTrue(CorsUtils.isCorsRequest(request) && !CorsUtils.isPreFlightRequest(request)); + + ServletServerHttpResponse serverResponse = new ServletServerHttpResponse(response); + if (responseHasCors(serverResponse)) { return true; } - if (check(serverRequest, serverResponse, config, isPreFlight)) { - setAllowOrigin(serverRequest, serverResponse, config.getAllowedOrigins(), config.isAllowCredentials()); - setAllowCredentials(serverResponse, config.isAllowCredentials()); - setExposeHeadersHeader(serverResponse, config.getExposedHeaders()); - serverResponse.close(); + ServletServerHttpRequest serverRequest = new ServletServerHttpRequest(request); + if (handleInternal(serverRequest, serverResponse, config, false)) { + serverResponse.flush(); return true; } + return false; } - private boolean skip(ServerHttpResponse response) { - if (hasAllowOriginHeader(response)) { - logger.debug("Skip adding CORS headers, response already contains \"Access-Control-Allow-Origin\""); - return true; - } - return false; - } - - private boolean check(ServerHttpRequest request, ServerHttpResponse response, - CorsConfiguration config, boolean isPreFlight) throws IOException { - - if (!checkOrigin(request, config.getAllowedOrigins()) || - !checkMethod(request, config.getAllowedMethods(), isPreFlight) || - !checkHeaders(request, config.getAllowedHeaders(), isPreFlight)) { - response.setStatusCode(HttpStatus.FORBIDDEN); - response.getBody().write("Invalid CORS request".getBytes(UTF8_CHARSET)); - return false; - } - return true; - } - - private boolean hasAllowOriginHeader(ServerHttpResponse response) { - boolean hasCorsResponseHeaders = false; + private boolean responseHasCors(ServerHttpResponse response) { + boolean hasAllowOrigin = false; try { - // Perhaps a CORS Filter has already added this? - hasCorsResponseHeaders = response.getHeaders().getAccessControlAllowOrigin() != null; + hasAllowOrigin = (response.getHeaders().getAccessControlAllowOrigin() != null); } catch (NullPointerException npe) { - // See SPR-11919 and https://issues.jboss.org/browse/WFLY-3474 + // SPR-11919 and https://issues.jboss.org/browse/WFLY-3474 } - return hasCorsResponseHeaders; + if (hasAllowOrigin) { + logger.debug("Skip adding CORS headers, response already contains \"Access-Control-Allow-Origin\""); + } + return hasAllowOrigin; } - private boolean checkOrigin(ServerHttpRequest request, List allowedOrigins) { - String originHeader = request.getHeaders().getOrigin(); - if (originHeader == null || allowedOrigins == null) { + protected boolean handleInternal(ServerHttpRequest request, ServerHttpResponse response, + CorsConfiguration config, boolean isPreFlight) throws IOException { + + String requestOrigin = request.getHeaders().getOrigin(); + String allowOrigin = checkOrigin(config, requestOrigin); + + HttpMethod requestMethod = getMethodToUse(request, isPreFlight); + List allowMethods = checkMethods(config, requestMethod); + + List requestHeaders = getHeadersToUse(request, isPreFlight); + List allowHeaders = checkHeaders(config, requestHeaders); + + if (allowOrigin == null || allowMethods == null || (isPreFlight && allowHeaders == null)) { + handleInvalidCorsRequest(response); return false; } - if (allowedOrigins.contains("*")) { - return true; - } - for (String allowedOrigin : allowedOrigins) { - if (originHeader.equalsIgnoreCase(allowedOrigin)) { - return true; - } - } - return false; - } - private boolean checkMethod(ServerHttpRequest request, List allowedMethods, boolean isPreFlight) { - HttpMethod requestMethod = isPreFlight ? - request.getHeaders().getAccessControlRequestMethod() : - request.getMethod(); - if (allowedMethods == null) { - allowedMethods = Arrays.asList(HttpMethod.GET.name()); - } - if (allowedMethods.contains("*")) { - return true; - } - for (String allowedMethod : allowedMethods) { - if (allowedMethod.equalsIgnoreCase(requestMethod.name())) { - return true; - } - } - return false; - } + HttpHeaders responseHeaders = response.getHeaders(); + responseHeaders.setAccessControlAllowOrigin(allowOrigin); + responseHeaders.add(HttpHeaders.VARY, HttpHeaders.ORIGIN); - private boolean checkHeaders(ServerHttpRequest request, List allowedHeaders, boolean isPreFlight) { - List requestHeaders = isPreFlight ? request.getHeaders().getAccessControlRequestHeaders() : - new ArrayList(request.getHeaders().keySet()); - if ((allowedHeaders != null) && allowedHeaders.contains("*")) { - return true; + if (isPreFlight) { + responseHeaders.setAccessControlAllowMethods(allowMethods); } - for (String requestHeader : requestHeaders) { - if (!HttpHeaders.ORIGIN.equals(requestHeader)) { - requestHeader = requestHeader.trim(); - boolean found = false; - if (allowedHeaders != null) { - for (String header : allowedHeaders) { - if (requestHeader.equalsIgnoreCase(header)) { - found = true; - break; - } - } - } - if (!found) { - return false; - } - } + + if (isPreFlight && !allowHeaders.isEmpty()) { + responseHeaders.setAccessControlAllowHeaders(allowHeaders); } + + if (!CollectionUtils.isEmpty(config.getExposedHeaders())) { + responseHeaders.setAccessControlExposeHeaders(config.getExposedHeaders()); + } + + if (Boolean.TRUE.equals(config.getAllowCredentials())) { + responseHeaders.setAccessControlAllowCredentials(true); + } + + if (isPreFlight && config.getMaxAge() != null) { + responseHeaders.setAccessControlMaxAge(config.getMaxAge()); + } + return true; } - private void setAllowOrigin(ServerHttpRequest request, ServerHttpResponse response, - List allowedOrigins, Boolean allowCredentials) { - - String origin = request.getHeaders().getOrigin(); - if (allowedOrigins.contains("*") && (allowCredentials == null || !allowCredentials)) { - response.getHeaders().setAccessControlAllowOrigin("*"); - return; - } - response.getHeaders().setAccessControlAllowOrigin(origin); - response.getHeaders().add(HttpHeaders.VARY, HttpHeaders.ORIGIN); + /** + * Check the origin and determine the origin for the response. The default + * implementation simply delegates to + * {@link org.springframework.web.cors.CorsConfiguration#checkOrigin(String)} + */ + protected String checkOrigin(CorsConfiguration config, String requestOrigin) { + return config.checkOrigin(requestOrigin); } - private void setAllowMethods(ServerHttpRequest request, ServerHttpResponse response, - List allowedMethods) { - - if (allowedMethods == null) { - allowedMethods = Arrays.asList(HttpMethod.GET.name()); - } - if (allowedMethods.contains("*")) { - HttpMethod method = request.getHeaders().getAccessControlRequestMethod(); - response.getHeaders().setAccessControlAllowMethods(Arrays.asList(method)); - } - else { - List methods = new ArrayList(); - for (String method : allowedMethods) { - methods.add(HttpMethod.valueOf(method)); - } - response.getHeaders().setAccessControlAllowMethods(methods); - } + /** + * Check the HTTP method and determine the methods for the response of a + * pre-flight request. The default implementation simply delegates to + * {@link org.springframework.web.cors.CorsConfiguration#checkOrigin(String)} + */ + protected List checkMethods(CorsConfiguration config, HttpMethod requestMethod) { + return config.checkHttpMethod(requestMethod); } - private void setAllowHeadersHeader(ServerHttpRequest request, ServerHttpResponse response, - List allowedHeaders) { - if ((allowedHeaders != null) && !allowedHeaders.isEmpty()) { - List requestHeaders = request.getHeaders().getAccessControlRequestHeaders(); - boolean matchAll = allowedHeaders.contains("*"); - List matchingHeaders = new ArrayList(); - for (String requestHeader : requestHeaders) { - for (String header : allowedHeaders) { - requestHeader = requestHeader.trim(); - if (matchAll || requestHeader.equalsIgnoreCase(header)) { - matchingHeaders.add(requestHeader); - break; - } - } - } - if (!matchingHeaders.isEmpty()) { - response.getHeaders().setAccessControlAllowHeaders(matchingHeaders); - } - } + private HttpMethod getMethodToUse(ServerHttpRequest request, boolean isPreFlight) { + return (isPreFlight ? request.getHeaders().getAccessControlRequestMethod() : request.getMethod()); } - private void setExposeHeadersHeader(ServerHttpResponse response, List exposedHeaders) { - if ((exposedHeaders != null) && !exposedHeaders.isEmpty()) { - response.getHeaders().setAccessControlExposeHeaders(exposedHeaders); - } + /** + * Check the headers and determine the headers for the response of a + * pre-flight request. The default implementation simply delegates to + * {@link org.springframework.web.cors.CorsConfiguration#checkOrigin(String)} + */ + protected List checkHeaders(CorsConfiguration config, List requestHeaders) { + return config.checkHeaders(requestHeaders); } - private void setAllowCredentials(ServerHttpResponse response, Boolean allowCredentials) { - if ((allowCredentials != null) && allowCredentials) { - response.getHeaders().setAccessControlAllowCredentials(allowCredentials); - } + private List getHeadersToUse(ServerHttpRequest request, boolean isPreFlight) { + HttpHeaders headers = request.getHeaders(); + return (isPreFlight ? headers.getAccessControlRequestHeaders() : new ArrayList(headers.keySet())); } - private void setMaxAgeHeader(ServerHttpResponse response, Long maxAge) { - if (maxAge != null) { - response.getHeaders().setAccessControlMaxAge(maxAge); - } + /** + * Invoked when one of the CORS checks failed. + * The default implementation sets the response status to 403 and writes + * "Invalid CORS request" to the response. + */ + protected void handleInvalidCorsRequest(ServerHttpResponse response) throws IOException { + response.setStatusCode(HttpStatus.FORBIDDEN); + response.getBody().write("Invalid CORS request".getBytes(UTF8_CHARSET)); } } diff --git a/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java b/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java new file mode 100644 index 0000000000..081290c9e0 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpMethod; + +/** + * Test case for {@link CorsConfiguration}. + * + * @author Sebastien Deleuze + */ +public class CorsConfigurationTests { + + private CorsConfiguration config; + + @Before + public void setup() { + config = new CorsConfiguration(); + } + + @Test + public void checkOriginAllowed() { + config.setAllowedOrigins(Arrays.asList("*")); + assertEquals("*", config.checkOrigin("http://domain.com")); + config.setAllowCredentials(true); + assertEquals("http://domain.com", config.checkOrigin("http://domain.com")); + config.setAllowedOrigins(Arrays.asList("http://domain.com")); + assertEquals("http://domain.com", config.checkOrigin("http://domain.com")); + config.setAllowCredentials(false); + assertEquals("http://domain.com", config.checkOrigin("http://domain.com")); + } + + @Test + public void checkOriginNotAllowed() { + assertNull(config.checkOrigin(null)); + assertNull(config.checkOrigin("http://domain.com")); + config.addAllowedOrigin("*"); + assertNull(config.checkOrigin(null)); + config.setAllowedOrigins(Arrays.asList("http://domain1.com")); + assertNull(config.checkOrigin("http://domain2.com")); + config.setAllowedOrigins(new ArrayList<>()); + assertNull(config.checkOrigin("http://domain.com")); + } + + @Test + public void checkMethodAllowed() { + assertEquals(Arrays.asList(HttpMethod.GET), config.checkHttpMethod(HttpMethod.GET)); + config.addAllowedMethod("GET"); + assertEquals(Arrays.asList(HttpMethod.GET), config.checkHttpMethod(HttpMethod.GET)); + config.addAllowedMethod("POST"); + assertEquals(Arrays.asList(HttpMethod.GET, HttpMethod.POST), config.checkHttpMethod(HttpMethod.GET)); + assertEquals(Arrays.asList(HttpMethod.GET, HttpMethod.POST), config.checkHttpMethod(HttpMethod.POST)); + } + + @Test + public void checkMethodNotAllowed() { + assertNull(config.checkHttpMethod(null)); + assertNull(config.checkHttpMethod(HttpMethod.DELETE)); + config.setAllowedMethods(new ArrayList<>()); + assertNull(config.checkHttpMethod(HttpMethod.HEAD)); + } + + @Test + public void checkHeadersAllowed() { + assertEquals(Collections.emptyList(), config.checkHeaders(Collections.emptyList())); + config.addAllowedHeader("header1"); + config.addAllowedHeader("header2"); + assertEquals(Arrays.asList("header1"), config.checkHeaders(Arrays.asList("header1"))); + assertEquals(Arrays.asList("header1", "header2"), config.checkHeaders(Arrays.asList("header1", "header2"))); + assertEquals(Arrays.asList("header1", "header2"), config.checkHeaders(Arrays.asList("header1", "header2", "header3"))); + } + + @Test + public void checkHeadersNotAllowed() { + assertNull(config.checkHeaders(null)); + assertNull(config.checkHeaders(Arrays.asList("header1"))); + config.addAllowedHeader("header2"); + assertNull(config.checkHeaders(Arrays.asList("header1"))); + config.setAllowedHeaders(new ArrayList<>()); + assertNull(config.checkHeaders(Arrays.asList("header1"))); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java b/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java index e0ee8efedf..bdcc188035 100644 --- a/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java +++ b/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java @@ -243,7 +243,7 @@ public class DefaultCorsProcessorTests { } @Test - public void preflightRequestCrendentialsWithOriginWildcard() throws Exception { + public void preflightRequestCredentialsWithOriginWildcard() throws Exception { this.request.setMethod(HttpMethod.OPTIONS.name()); this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1"); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java index 2c28728495..615ae81e9d 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java @@ -35,7 +35,6 @@ import org.springframework.mock.web.test.MockHttpServletResponse; import org.springframework.web.HttpRequestHandler; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.context.support.StaticWebApplicationContext; -import org.springframework.web.cors.CorsUtils; import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.support.WebContentGenerator; diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java index a737bc410d..55b10a7504 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java @@ -25,6 +25,7 @@ import org.junit.Test; import org.springframework.beans.DirectFieldAccessor; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.http.HttpHeaders; +import org.springframework.util.CollectionUtils; import org.springframework.web.context.support.StaticWebApplicationContext; import org.springframework.web.cors.CorsConfiguration; import org.springframework.mock.web.test.MockHttpServletRequest; @@ -101,9 +102,9 @@ public class CrossOriginTests { assertNotNull(config); assertArrayEquals(new String[]{"GET"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray()); - assertTrue(config.isAllowCredentials()); + assertTrue(config.getAllowCredentials()); assertArrayEquals(new String[]{"*"}, config.getAllowedHeaders().toArray()); - assertNull(config.getExposedHeaders()); + assertTrue(CollectionUtils.isEmpty(config.getExposedHeaders())); assertEquals(new Long(1800), config.getMaxAge()); } @@ -119,7 +120,7 @@ public class CrossOriginTests { assertArrayEquals(new String[]{"header1", "header2"}, config.getAllowedHeaders().toArray()); assertArrayEquals(new String[]{"header3", "header4"}, config.getExposedHeaders().toArray()); assertEquals(new Long(123), config.getMaxAge()); - assertEquals(false, config.isAllowCredentials()); + assertEquals(false, config.getAllowCredentials()); } @Test @@ -144,9 +145,9 @@ public class CrossOriginTests { assertNotNull(config); assertArrayEquals(new String[]{"GET"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray()); - assertTrue(config.isAllowCredentials()); + assertTrue(config.getAllowCredentials()); assertArrayEquals(new String[]{"*"}, config.getAllowedHeaders().toArray()); - assertNull(config.getExposedHeaders()); + assertTrue(CollectionUtils.isEmpty(config.getExposedHeaders())); assertEquals(new Long(1800), config.getMaxAge()); } @@ -163,8 +164,8 @@ public class CrossOriginTests { assertArrayEquals(new String[]{"*"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray()); assertArrayEquals(new String[]{"*"}, config.getAllowedHeaders().toArray()); - assertTrue(config.isAllowCredentials()); - assertNull(config.getExposedHeaders()); + assertTrue(config.getAllowCredentials()); + assertTrue(CollectionUtils.isEmpty(config.getExposedHeaders())); assertNull(config.getMaxAge()); } @@ -180,8 +181,8 @@ public class CrossOriginTests { assertArrayEquals(new String[]{"*"}, config.getAllowedMethods().toArray()); assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray()); assertArrayEquals(new String[]{"*"}, config.getAllowedHeaders().toArray()); - assertTrue(config.isAllowCredentials()); - assertNull(config.getExposedHeaders()); + assertTrue(config.getAllowCredentials()); + assertTrue(CollectionUtils.isEmpty(config.getExposedHeaders())); assertNull(config.getMaxAge()); }