diff --git a/spring-data-rest-tests/spring-data-rest-tests-jpa/src/test/java/org/springframework/data/rest/webmvc/jpa/CorsIntegrationTests.java b/spring-data-rest-tests/spring-data-rest-tests-jpa/src/test/java/org/springframework/data/rest/webmvc/jpa/CorsIntegrationTests.java index 24cdd5455..65f2ed791 100755 --- a/spring-data-rest-tests/spring-data-rest-tests-jpa/src/test/java/org/springframework/data/rest/webmvc/jpa/CorsIntegrationTests.java +++ b/spring-data-rest-tests/spring-data-rest-tests-jpa/src/test/java/org/springframework/data/rest/webmvc/jpa/CorsIntegrationTests.java @@ -15,6 +15,7 @@ */ package org.springframework.data.rest.webmvc.jpa; +import static org.assertj.core.api.Assertions.*; import static org.hamcrest.Matchers.*; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; @@ -25,10 +26,12 @@ import org.springframework.context.annotation.Configuration; import org.springframework.data.rest.tests.AbstractWebIntegrationTests; import org.springframework.data.rest.webmvc.BasePathAwareController; import org.springframework.data.rest.webmvc.RepositoryRestController; +import org.springframework.data.rest.webmvc.RepositoryRestHandlerMapping; import org.springframework.data.rest.webmvc.config.RepositoryRestConfigurer; import org.springframework.hateoas.Link; import org.springframework.hateoas.LinkRelation; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.test.context.ContextConfiguration; import org.springframework.web.bind.annotation.CrossOrigin; @@ -83,10 +86,15 @@ public class CorsIntegrationTests extends AbstractWebIntegrationTests { Link findItems = client.discoverUnique(LinkRelation.of("items")); // Preflight request - mvc.perform(options(findItems.expand().getHref()).header(HttpHeaders.ORIGIN, "http://far.far.example") - .header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST")) // + String header = mvc + .perform(options(findItems.expand().getHref()).header(HttpHeaders.ORIGIN, "http://far.far.example") + .header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST")) // .andExpect(status().isOk()) // - .andExpect(header().string(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS, "GET,HEAD,POST")); + .andReturn().getResponse().getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS); + + assertThat(header.split(",")) + .containsExactlyInAnyOrderElementsOf( + RepositoryRestHandlerMapping.DEFAULT_ALLOWED_METHODS.map(HttpMethod::name)); } @Test // DATAREST-573 diff --git a/spring-data-rest-tests/spring-data-rest-tests-jpa/src/test/java/org/springframework/data/rest/webmvc/jpa/LocalConfigCorsIntegrationTests.java b/spring-data-rest-tests/spring-data-rest-tests-jpa/src/test/java/org/springframework/data/rest/webmvc/jpa/LocalConfigCorsIntegrationTests.java index b3eb6b9a3..2af3e7085 100755 --- a/spring-data-rest-tests/spring-data-rest-tests-jpa/src/test/java/org/springframework/data/rest/webmvc/jpa/LocalConfigCorsIntegrationTests.java +++ b/spring-data-rest-tests/spring-data-rest-tests-jpa/src/test/java/org/springframework/data/rest/webmvc/jpa/LocalConfigCorsIntegrationTests.java @@ -15,16 +15,19 @@ */ package org.springframework.data.rest.webmvc.jpa; +import static org.assertj.core.api.Assertions.*; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; import org.junit.Test; import org.springframework.context.annotation.Bean; import org.springframework.data.rest.tests.AbstractWebIntegrationTests; +import org.springframework.data.rest.webmvc.RepositoryRestHandlerMapping; import org.springframework.data.rest.webmvc.config.RepositoryRestConfigurer; import org.springframework.hateoas.Link; import org.springframework.hateoas.LinkRelation; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; import org.springframework.test.context.ContextConfiguration; /** @@ -53,9 +56,14 @@ public class LocalConfigCorsIntegrationTests extends AbstractWebIntegrationTests Link findItems = client.discoverUnique(LinkRelation.of("items")); // Preflight request - mvc.perform(options(findItems.expand().getHref()).header(HttpHeaders.ORIGIN, "http://far.far.example") - .header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST")) // + String header = mvc + .perform(options(findItems.expand().getHref()).header(HttpHeaders.ORIGIN, "http://far.far.example") + .header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST")) // .andExpect(status().isOk()) // - .andExpect(header().string(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS, "GET,HEAD,POST")); + .andReturn().getResponse().getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS); + + assertThat(header.split(",")) + .containsExactlyInAnyOrderElementsOf( + RepositoryRestHandlerMapping.DEFAULT_ALLOWED_METHODS.map(HttpMethod::name)); } } diff --git a/spring-data-rest-webmvc/src/main/java/org/springframework/data/rest/webmvc/RepositoryRestHandlerMapping.java b/spring-data-rest-webmvc/src/main/java/org/springframework/data/rest/webmvc/RepositoryRestHandlerMapping.java index 8840059b8..8e9dfca1d 100644 --- a/spring-data-rest-webmvc/src/main/java/org/springframework/data/rest/webmvc/RepositoryRestHandlerMapping.java +++ b/spring-data-rest-webmvc/src/main/java/org/springframework/data/rest/webmvc/RepositoryRestHandlerMapping.java @@ -28,10 +28,13 @@ import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.data.repository.support.Repositories; import org.springframework.data.rest.core.config.RepositoryRestConfiguration; +import org.springframework.data.rest.core.mapping.HttpMethods; import org.springframework.data.rest.core.mapping.ResourceMappings; import org.springframework.data.rest.core.mapping.ResourceMetadata; import org.springframework.data.rest.webmvc.support.JpaHelper; import org.springframework.data.util.ProxyUtils; +import org.springframework.data.util.Streamable; +import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.orm.jpa.support.OpenEntityManagerInViewInterceptor; import org.springframework.util.Assert; @@ -58,6 +61,10 @@ import org.springframework.web.util.pattern.PathPatternParser; */ public class RepositoryRestHandlerMapping extends BasePathAwareHandlerMapping { + public static final HttpMethods DEFAULT_ALLOWED_METHODS = HttpMethods.none() + .and(HttpMethod.values()) + .butWithout(HttpMethod.TRACE); + private static final PathPatternParser PARSER = new PathPatternParser(); static final String EFFECTIVE_LOOKUP_PATH_ATTRIBUTE = RepositoryRestHandlerMapping.class.getName() + ".EFFECTIVE_REPOSITORY_RESOURCE_LOOKUP_PATH"; @@ -361,7 +368,7 @@ public class RepositoryRestHandlerMapping extends BasePathAwareHandlerMapping { config.addAllowedOrigin(resolveCorsAnnotationValue(origin)); } - for (RequestMethod method : annotation.methods()) { + for (HttpMethod method : getAllowedMethods(annotation)) { config.addAllowedMethod(method.name()); } @@ -392,5 +399,24 @@ public class RepositoryRestHandlerMapping extends BasePathAwareHandlerMapping { private String resolveCorsAnnotationValue(String value) { return this.embeddedValueResolver.resolveStringValue(value); } + + /** + * Returns the {@link HttpMethods} configured on the given annotation or the default methods to support. + * + * @param annotation must not be {@literal null}. + * @return + * @see #DEFAULT_ALLOWED_METHODS + */ + private static HttpMethods getAllowedMethods(CrossOrigin annotation) { + + RequestMethod[] methods = annotation.methods(); + + return methods.length == 0 + ? DEFAULT_ALLOWED_METHODS + : HttpMethods.of(Streamable.of(methods) + .map(RequestMethod::name) + .map(HttpMethod::resolve) + .toList()); + } } } diff --git a/spring-data-rest-webmvc/src/test/java/org/springframework/data/rest/webmvc/RepositoryCorsConfigurationAccessorUnitTests.java b/spring-data-rest-webmvc/src/test/java/org/springframework/data/rest/webmvc/RepositoryCorsConfigurationAccessorUnitTests.java index 6aa6fa3a8..469cce37b 100755 --- a/spring-data-rest-webmvc/src/test/java/org/springframework/data/rest/webmvc/RepositoryCorsConfigurationAccessorUnitTests.java +++ b/spring-data-rest-webmvc/src/test/java/org/springframework/data/rest/webmvc/RepositoryCorsConfigurationAccessorUnitTests.java @@ -64,7 +64,9 @@ public class RepositoryCorsConfigurationAccessorUnitTests { assertThat(configuration.getAllowCredentials()).isNull(); assertThat(configuration.getAllowedHeaders()).contains("*"); assertThat(configuration.getAllowedOrigins()).contains("*"); - assertThat(configuration.getAllowedMethods()).contains("HEAD", "GET", "POST"); + assertThat(configuration.getAllowedMethods()) + .contains("HEAD", "GET", "POST", "PUT", "PATCH", "OPTIONS") + .doesNotContain("TRACE"); assertThat(configuration.getMaxAge()).isEqualTo(1800L); }