diff --git a/spring-web/src/main/java/org/springframework/web/server/adapter/WebHttpHandlerBuilder.java b/spring-web/src/main/java/org/springframework/web/server/adapter/WebHttpHandlerBuilder.java index 0c3958d4c8..57c04cd302 100644 --- a/spring-web/src/main/java/org/springframework/web/server/adapter/WebHttpHandlerBuilder.java +++ b/spring-web/src/main/java/org/springframework/web/server/adapter/WebHttpHandlerBuilder.java @@ -20,8 +20,8 @@ import java.util.Arrays; import java.util.Collection; import java.util.List; -import org.springframework.beans.factory.BeanFactoryUtils; import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.core.annotation.AnnotationAwareOrderComparator; import org.springframework.http.server.reactive.HttpHandler; @@ -118,21 +118,15 @@ public class WebHttpHandlerBuilder { // WebFilter... - Collection filters = BeanFactoryUtils.beansOfTypeIncludingAncestors( - context, WebFilter.class, true, false).values(); - - WebFilter[] sortedFilters = filters.toArray(new WebFilter[filters.size()]); - AnnotationAwareOrderComparator.sort(sortedFilters); - builder.filters(sortedFilters); + AutowiredFiltersContainer filtersContainer = new AutowiredFiltersContainer(); + context.getAutowireCapableBeanFactory().autowireBean(filtersContainer); + builder.filters(filtersContainer.getFilters()); // WebExceptionHandler... - Collection handlers = BeanFactoryUtils.beansOfTypeIncludingAncestors( - context, WebExceptionHandler.class, true, false).values(); - - WebExceptionHandler[] sortedHandlers = handlers.toArray(new WebExceptionHandler[handlers.size()]); - AnnotationAwareOrderComparator.sort(sortedHandlers); - builder.exceptionHandlers(sortedHandlers); + AutowiredExceptionHandlersContainer handlersContainer = new AutowiredExceptionHandlersContainer(); + context.getAutowireCapableBeanFactory().autowireBean(handlersContainer); + builder.exceptionHandlers(handlersContainer.getExceptionHandlers()); // WebSessionManager @@ -153,8 +147,16 @@ public class WebHttpHandlerBuilder { * @param filters the filters to add */ public WebHttpHandlerBuilder filters(WebFilter... filters) { + return filters(Arrays.asList(filters)); + } + + /** + * Add the given filters to use for processing requests. + * @param filters the filters to add + */ + public WebHttpHandlerBuilder filters(Collection filters) { if (!ObjectUtils.isEmpty(filters)) { - this.filters.addAll(Arrays.asList(filters)); + this.filters.addAll(filters); } return this; } @@ -164,8 +166,16 @@ public class WebHttpHandlerBuilder { * @param exceptionHandlers the exception handlers */ public WebHttpHandlerBuilder exceptionHandlers(WebExceptionHandler... exceptionHandlers) { + return exceptionHandlers(Arrays.asList(exceptionHandlers)); + } + + /** + * Add the given exception handler to apply at the end of request processing. + * @param exceptionHandlers the exception handlers + */ + public WebHttpHandlerBuilder exceptionHandlers(List exceptionHandlers) { if (!ObjectUtils.isEmpty(exceptionHandlers)) { - this.exceptionHandlers.addAll(Arrays.asList(exceptionHandlers)); + this.exceptionHandlers.addAll(exceptionHandlers); } return this; } @@ -201,4 +211,33 @@ public class WebHttpHandlerBuilder { return httpHandler; } + + private static class AutowiredFiltersContainer { + + private List filters; + + @Autowired(required = false) + public void setFilters(List filters) { + this.filters = filters; + } + + public List getFilters() { + return this.filters; + } + } + + private static class AutowiredExceptionHandlersContainer { + + private List exceptionHandlers; + + @Autowired(required = false) + public void setExceptionHandlers(List exceptionHandlers) { + this.exceptionHandlers = exceptionHandlers; + } + + public List getExceptionHandlers() { + return this.exceptionHandlers; + } + } + } diff --git a/spring-web/src/test/java/org/springframework/web/server/adapter/WebHttpHandlerBuilderTests.java b/spring-web/src/test/java/org/springframework/web/server/adapter/WebHttpHandlerBuilderTests.java new file mode 100644 index 0000000000..44a3f896dc --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/adapter/WebHttpHandlerBuilderTests.java @@ -0,0 +1,166 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.server.adapter; + +import java.nio.charset.StandardCharsets; + +import org.jetbrains.annotations.NotNull; +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.annotation.Order; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebExceptionHandler; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebHandler; + +import static org.junit.Assert.assertEquals; + +/** + * Unit tests for {@link WebHttpHandlerBuilder}. + * @author Rossen Stoyanchev + */ +public class WebHttpHandlerBuilderTests { + + + @Test // SPR-15074 + public void orderedWebFilterBeans() throws Exception { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(OrderedWebFilterBeanConfig.class); + context.refresh(); + + HttpHandler httpHandler = WebHttpHandlerBuilder.applicationContext(context).build(); + + MockServerHttpRequest request = MockServerHttpRequest.get("/").build(); + MockServerHttpResponse response = new MockServerHttpResponse(); + httpHandler.handle(request, response).blockMillis(5000); + + assertEquals("FilterB::FilterA", response.getBodyAsString().blockMillis(5000)); + } + + @Test // SPR-15074 + public void orderedWebExceptionHandlerBeans() throws Exception { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(OrderedExceptionHandlerBeanConfig.class); + context.refresh(); + + HttpHandler httpHandler = WebHttpHandlerBuilder.applicationContext(context).build(); + + MockServerHttpRequest request = MockServerHttpRequest.get("/").build(); + MockServerHttpResponse response = new MockServerHttpResponse(); + httpHandler.handle(request, response).blockMillis(5000); + + assertEquals("ExceptionHandlerB", response.getBodyAsString().blockMillis(5000)); + } + + @Test + public void configWithoutFilters() throws Exception { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(NoFilterConfig.class); + context.refresh(); + + HttpHandler httpHandler = WebHttpHandlerBuilder.applicationContext(context).build(); + + MockServerHttpRequest request = MockServerHttpRequest.get("/").build(); + MockServerHttpResponse response = new MockServerHttpResponse(); + httpHandler.handle(request, response).blockMillis(5000); + + assertEquals("handled", response.getBodyAsString().blockMillis(5000)); + } + + private static Mono writeToResponse(ServerWebExchange exchange, String value) { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + DataBuffer buffer = new DefaultDataBufferFactory().wrap(bytes); + return exchange.getResponse().writeWith(Flux.just(buffer)); + } + + + @Configuration + @SuppressWarnings("unused") + static class OrderedWebFilterBeanConfig { + + private static final String ATTRIBUTE = "attr"; + + @Bean @Order(2) + public WebFilter filterA() { + return createFilter("FilterA"); + } + + @Bean @Order(1) + public WebFilter filterB() { + return createFilter("FilterB"); + } + + @NotNull + private WebFilter createFilter(String name) { + return (exchange, chain) -> { + String value = exchange.getAttribute(ATTRIBUTE).map(v -> v + "::" + name).orElse(name); + exchange.getAttributes().put(ATTRIBUTE, value); + return chain.filter(exchange); + }; + } + + @Bean + public WebHandler webHandler() { + return exchange -> { + String value = exchange.getAttribute(ATTRIBUTE).map(v -> (String) v).orElse("none"); + return writeToResponse(exchange, value); + }; + } + } + + @Configuration + @SuppressWarnings("unused") + static class OrderedExceptionHandlerBeanConfig { + + @Bean + @Order(2) + public WebExceptionHandler exceptionHandlerA() { + return (exchange, ex) -> writeToResponse(exchange, "ExceptionHandlerA"); + } + + @Bean + @Order(1) + public WebExceptionHandler exceptionHandlerB() { + return (exchange, ex) -> writeToResponse(exchange, "ExceptionHandlerB"); + } + + @Bean + public WebHandler webHandler() { + return exchange -> Mono.error(new Exception()); + } + } + + @Configuration + @SuppressWarnings("unused") + static class NoFilterConfig { + + @Bean + public WebHandler webHandler() { + return exchange -> writeToResponse(exchange, "handled"); + } + } + +}