diff --git a/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java index 7403938469..f964afd081 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java +++ b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-201/ the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package org.springframework.web.cors; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -121,7 +122,8 @@ public class DefaultCorsProcessor implements CorsProcessor { String allowOrigin = checkOrigin(config, requestOrigin); HttpHeaders responseHeaders = response.getHeaders(); - responseHeaders.add(HttpHeaders.VARY, HttpHeaders.ORIGIN); + responseHeaders.addAll(HttpHeaders.VARY, Arrays.asList(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); if (allowOrigin == null) { logger.debug("Rejecting CORS request because '" + requestOrigin + "' origin is not allowed"); diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java index 39ecfae87b..418d9acee0 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.web.cors.reactive; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.apache.commons.logging.Log; @@ -107,7 +108,8 @@ public class DefaultCorsProcessor implements CorsProcessor { ServerHttpResponse response = exchange.getResponse(); HttpHeaders responseHeaders = response.getHeaders(); - response.getHeaders().add(HttpHeaders.VARY, HttpHeaders.ORIGIN); + response.getHeaders().addAll(HttpHeaders.VARY, Arrays.asList(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); String requestOrigin = request.getHeaders().getOrigin(); String allowOrigin = checkOrigin(config, requestOrigin); diff --git a/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java b/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java index c929c10e4c..7c054c72d7 100644 --- a/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java +++ b/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import org.springframework.http.HttpMethod; import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.mock.web.test.MockHttpServletResponse; +import static org.hamcrest.Matchers.contains; import static org.junit.Assert.*; /** @@ -65,7 +66,8 @@ public class DefaultCorsProcessorTests { this.processor.processRequest(this.conf, this.request, this.response); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); } @@ -90,7 +92,8 @@ public class DefaultCorsProcessorTests { assertEquals("*", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); } @@ -108,7 +111,8 @@ public class DefaultCorsProcessorTests { assertEquals("http://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals("true", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); } @@ -124,7 +128,8 @@ public class DefaultCorsProcessorTests { assertEquals("http://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals("true", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); } @@ -136,7 +141,8 @@ public class DefaultCorsProcessorTests { this.processor.processRequest(this.conf, this.request, this.response); assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); } @@ -154,7 +160,8 @@ public class DefaultCorsProcessorTests { assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1")); assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2")); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); } @@ -166,7 +173,8 @@ public class DefaultCorsProcessorTests { this.conf.addAllowedOrigin("*"); this.processor.processRequest(this.conf, this.request, this.response); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); } @@ -178,7 +186,8 @@ public class DefaultCorsProcessorTests { this.conf.addAllowedOrigin("*"); this.processor.processRequest(this.conf, this.request, this.response); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); } @@ -192,7 +201,8 @@ public class DefaultCorsProcessorTests { this.processor.processRequest(this.conf, this.request, this.response); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); assertEquals("GET,HEAD", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); } @Test @@ -202,7 +212,8 @@ public class DefaultCorsProcessorTests { this.processor.processRequest(this.conf, this.request, this.response); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); } @@ -214,7 +225,8 @@ public class DefaultCorsProcessorTests { this.processor.processRequest(this.conf, this.request, this.response); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); } @@ -227,7 +239,8 @@ public class DefaultCorsProcessorTests { this.processor.processRequest(this.conf, this.request, this.response); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_FORBIDDEN, this.response.getStatus()); } @@ -249,7 +262,8 @@ public class DefaultCorsProcessorTests { assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); assertEquals("GET,PUT", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); } @@ -270,7 +284,8 @@ public class DefaultCorsProcessorTests { assertEquals("http://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals("true", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); } @@ -289,7 +304,8 @@ public class DefaultCorsProcessorTests { this.processor.processRequest(this.conf, this.request, this.response); assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("http://domain2.com", this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); } @@ -310,7 +326,8 @@ public class DefaultCorsProcessorTests { assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); assertFalse(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header3")); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); } @@ -329,7 +346,8 @@ public class DefaultCorsProcessorTests { assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); assertTrue(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); assertFalse(this.response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("*")); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); } @@ -345,7 +363,8 @@ public class DefaultCorsProcessorTests { this.processor.processRequest(this.conf, this.request, this.response); assertTrue(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS)); - assertEquals(HttpHeaders.ORIGIN, this.response.getHeader(HttpHeaders.VARY)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); } diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java index bf396e6a7b..e5500335de 100644 --- a/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java +++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,14 +28,18 @@ import org.springframework.mock.web.test.server.MockServerWebExchange; import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.server.ServerWebExchange; +import static org.hamcrest.Matchers.contains; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS; import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN; import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS; import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD; +import static org.springframework.http.HttpHeaders.ORIGIN; +import static org.springframework.http.HttpHeaders.VARY; /** * {@link DefaultCorsProcessor} tests with simple or pre-flight CORS request. @@ -63,7 +67,8 @@ public class DefaultCorsProcessorTests { ServerHttpResponse response = exchange.getResponse(); assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); } @@ -88,7 +93,8 @@ public class DefaultCorsProcessorTests { assertEquals("*", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); assertFalse(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); assertFalse(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertNull(response.getStatusCode()); } @@ -106,7 +112,8 @@ public class DefaultCorsProcessorTests { assertEquals("http://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals("true", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertNull(response.getStatusCode()); } @@ -122,7 +129,8 @@ public class DefaultCorsProcessorTests { assertEquals("http://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals("true", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertNull(response.getStatusCode()); } @@ -134,7 +142,8 @@ public class DefaultCorsProcessorTests { ServerHttpResponse response = exchange.getResponse(); assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertNull(response.getStatusCode()); } @@ -152,7 +161,8 @@ public class DefaultCorsProcessorTests { assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); assertTrue(response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1")); assertTrue(response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2")); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertNull(response.getStatusCode()); } @@ -164,7 +174,8 @@ public class DefaultCorsProcessorTests { this.processor.process(this.conf, exchange); ServerHttpResponse response = exchange.getResponse(); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertNull(response.getStatusCode()); } @@ -177,7 +188,8 @@ public class DefaultCorsProcessorTests { this.processor.process(this.conf, exchange); ServerHttpResponse response = exchange.getResponse(); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); } @@ -190,7 +202,8 @@ public class DefaultCorsProcessorTests { ServerHttpResponse response = exchange.getResponse(); assertNull(response.getStatusCode()); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals("GET,HEAD", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); } @@ -201,7 +214,8 @@ public class DefaultCorsProcessorTests { ServerHttpResponse response = exchange.getResponse(); assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); } @@ -213,7 +227,8 @@ public class DefaultCorsProcessorTests { ServerHttpResponse response = exchange.getResponse(); assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); } @@ -227,7 +242,8 @@ public class DefaultCorsProcessorTests { ServerHttpResponse response = exchange.getResponse(); assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertEquals(HttpStatus.FORBIDDEN, response.getStatusCode()); } @@ -251,7 +267,8 @@ public class DefaultCorsProcessorTests { assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); assertEquals("GET,PUT", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); assertFalse(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertNull(response.getStatusCode()); } @@ -274,7 +291,8 @@ public class DefaultCorsProcessorTests { assertEquals("http://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); assertTrue(response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); assertEquals("true", response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertNull(response.getStatusCode()); } @@ -295,7 +313,8 @@ public class DefaultCorsProcessorTests { ServerHttpResponse response = exchange.getResponse(); assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); assertEquals("http://domain2.com", response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertNull(response.getStatusCode()); } @@ -318,7 +337,8 @@ public class DefaultCorsProcessorTests { assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); assertFalse(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header3")); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertNull(response.getStatusCode()); } @@ -339,7 +359,8 @@ public class DefaultCorsProcessorTests { assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); assertTrue(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); assertFalse(response.getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS).contains("*")); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertNull(response.getStatusCode()); } @@ -357,7 +378,8 @@ public class DefaultCorsProcessorTests { ServerHttpResponse response = exchange.getResponse(); assertTrue(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_HEADERS)); - assertEquals(HttpHeaders.ORIGIN, response.getHeaders().getFirst(HttpHeaders.VARY)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); assertNull(response.getStatusCode()); } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/GlobalCorsConfigIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/GlobalCorsConfigIntegrationTests.java index 4a4a8af482..01d614f2de 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/GlobalCorsConfigIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/GlobalCorsConfigIntegrationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,7 +24,9 @@ import org.springframework.context.annotation.AnnotationConfigApplicationContext import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.Configuration; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.web.bind.annotation.GetMapping; @@ -34,8 +36,10 @@ import org.springframework.web.client.RestTemplate; import org.springframework.web.reactive.config.CorsRegistry; import org.springframework.web.reactive.config.WebFluxConfigurationSupport; +import static org.hamcrest.Matchers.contains; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; /** @@ -101,12 +105,22 @@ public class GlobalCorsConfigIntegrationTests extends AbstractRequestMappingInte assertEquals("welcome", entity.getBody()); } + @Test + public void actualRequestWithAmbiguousMapping() throws Exception { + this.headers.add(HttpHeaders.ACCEPT, MediaType.TEXT_HTML_VALUE); + ResponseEntity entity = performGet("/ambiguous", this.headers, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("*", entity.getHeaders().getAccessControlAllowOrigin()); + } + @Test public void preFlightRequestWithCorsEnabled() throws Exception { this.headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); ResponseEntity entity = performOptions("/cors", this.headers, String.class); assertEquals(HttpStatus.OK, entity.getStatusCode()); assertEquals("*", entity.getHeaders().getAccessControlAllowOrigin()); + assertThat(entity.getHeaders().getAccessControlAllowMethods(), + contains(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.POST)); } @Test @@ -133,6 +147,28 @@ public class GlobalCorsConfigIntegrationTests extends AbstractRequestMappingInte } } + @Test + public void preFlightRequestWithCorsRestricted() throws Exception { + this.headers.set(HttpHeaders.ORIGIN, "http://foo"); + this.headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + ResponseEntity entity = performOptions("/cors-restricted", this.headers, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("http://foo", entity.getHeaders().getAccessControlAllowOrigin()); + assertThat(entity.getHeaders().getAccessControlAllowMethods(), contains(HttpMethod.GET, HttpMethod.POST)); + } + + @Test + public void preFlightRequestWithAmbiguousMapping() throws Exception { + this.headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + ResponseEntity entity = performOptions("/ambiguous", this.headers, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("http://localhost:9000", entity.getHeaders().getAccessControlAllowOrigin()); + assertThat(entity.getHeaders().getAccessControlAllowMethods(), contains(HttpMethod.GET)); + assertEquals(true, entity.getHeaders().getAccessControlAllowCredentials()); + assertThat(entity.getHeaders().get(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + } + @Configuration @ComponentScan(resourcePattern = "**/GlobalCorsConfigIntegrationTests*.class") @@ -141,8 +177,12 @@ public class GlobalCorsConfigIntegrationTests extends AbstractRequestMappingInte @Override protected void addCorsMappings(CorsRegistry registry) { - registry.addMapping("/cors-restricted").allowedOrigins("http://foo"); + registry.addMapping("/cors-restricted") + .allowedOrigins("http://foo") + .allowedMethods("GET", "POST"); registry.addMapping("/cors"); + registry.addMapping("/ambiguous") + .allowedMethods("GET", "POST"); } } @@ -163,6 +203,16 @@ public class GlobalCorsConfigIntegrationTests extends AbstractRequestMappingInte public String corsRestricted() { return "corsRestricted"; } + + @GetMapping(value = "/ambiguous", produces = MediaType.TEXT_PLAIN_VALUE) + public String ambiguous1() { + return "ambiguous"; + } + + @GetMapping(value = "/ambiguous", produces = MediaType.TEXT_HTML_VALUE) + public String ambiguous2() { + return "

ambiguous

"; + } } }