diff --git a/spring-test/src/main/java/org/springframework/test/web/reactive/server/ExchangeMutatorWebFilter.java b/spring-test/src/main/java/org/springframework/test/web/reactive/server/ExchangeMutatorWebFilter.java index 148ceafec5..709fa74136 100644 --- a/spring-test/src/main/java/org/springframework/test/web/reactive/server/ExchangeMutatorWebFilter.java +++ b/spring-test/src/main/java/org/springframework/test/web/reactive/server/ExchangeMutatorWebFilter.java @@ -16,10 +16,9 @@ package org.springframework.test.web.reactive.server; -import java.util.ArrayList; -import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; import java.util.function.UnaryOperator; import reactor.core.publisher.Mono; @@ -38,9 +37,10 @@ import org.springframework.web.server.WebFilterChain; */ class ExchangeMutatorWebFilter implements WebFilter { - private volatile List> globalMutators = new ArrayList<>(4); + private volatile Function globalMutator; - private final Map> requestMutators = new ConcurrentHashMap<>(4); + private final Map> perRequestMutators = + new ConcurrentHashMap<>(4); /** @@ -49,7 +49,7 @@ class ExchangeMutatorWebFilter implements WebFilter { */ public void register(UnaryOperator mutator) { Assert.notNull(mutator, "'mutator' is required"); - this.globalMutators.add(mutator); + this.globalMutator = this.globalMutator != null ? this.globalMutator.andThen(mutator) : mutator; } /** @@ -58,19 +58,20 @@ class ExchangeMutatorWebFilter implements WebFilter { * @param mutator the transformation function */ public void register(String requestId, UnaryOperator mutator) { - this.requestMutators.put(requestId, mutator); + this.perRequestMutators.compute(requestId, + (s, value) -> value != null ? value.andThen(mutator) : mutator); } @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { - for (UnaryOperator mutator : this.globalMutators) { - exchange = mutator.apply(exchange); + if (this.globalMutator != null) { + exchange = this.globalMutator.apply(exchange); } String requestId = WiretapConnector.getRequestId(exchange.getRequest().getHeaders()); - UnaryOperator mutator = this.requestMutators.remove(requestId); + Function mutator = this.perRequestMutators.remove(requestId); if (mutator != null) { exchange = mutator.apply(exchange); } diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ApplicationContextTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ApplicationContextTests.java index a1f855740c..bad8a5e3c0 100644 --- a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ApplicationContextTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ApplicationContextTests.java @@ -28,6 +28,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestAttribute; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.reactive.config.EnableWebFlux; import org.springframework.web.server.ServerWebExchange; @@ -54,11 +55,11 @@ public class ApplicationContextTests { context.refresh(); this.client = WebTestClient.bindToApplicationContext(context) - .exchangeMutator(identitySetup("Pablo")) + .exchangeMutator(principal("Pablo")) .build(); } - private UnaryOperator identitySetup(String userName) { + private UnaryOperator principal(String userName) { return exchange -> { Principal user = mock(Principal.class); when(user.getName()).thenReturn(userName); @@ -69,19 +70,37 @@ public class ApplicationContextTests { @Test public void basic() throws Exception { - this.client.get().uri("/test") + this.client.get().uri("/principal") .exchange() .expectStatus().isOk() .expectBody(String.class).value().isEqualTo("Hello Pablo!"); } @Test - public void perRequestIdentityOverride() throws Exception { - this.client.exchangeMutator(identitySetup("Giovani")) - .get().uri("/test") + public void perRequestExchangeMutator() throws Exception { + this.client.exchangeMutator(principal("Giovanni")) + .get().uri("/principal") .exchange() .expectStatus().isOk() - .expectBody(String.class).value().isEqualTo("Hello Giovani!"); + .expectBody(String.class).value().isEqualTo("Hello Giovanni!"); + } + + @Test + public void perRequestMultipleExchangeMutators() throws Exception { + this.client + .exchangeMutator(attribute("attr1", "foo")) + .exchangeMutator(attribute("attr2", "bar")) + .get().uri("/attributes") + .exchange() + .expectStatus().isOk() + .expectBody(String.class).value().isEqualTo("foo+bar"); + } + + private UnaryOperator attribute(String attrName, String attrValue) { + return exchange -> { + exchange.getAttributes().put(attrName, attrValue); + return exchange; + }; } @@ -99,10 +118,15 @@ public class ApplicationContextTests { @RestController static class TestController { - @GetMapping("/test") + @GetMapping("/principal") public String handle(Principal principal) { return "Hello " + principal.getName() + "!"; } + + @GetMapping("/attributes") + public String handle(@RequestAttribute String attr1, @RequestAttribute String attr2) { + return attr1 + "+" + attr2; + } } } diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ControllerTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ControllerTests.java index 4c30e3ec83..8b4b63cd10 100644 --- a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ControllerTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ControllerTests.java @@ -24,6 +24,7 @@ import reactor.core.publisher.Mono; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestAttribute; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.server.ServerWebExchange; @@ -40,11 +41,11 @@ public class ControllerTests { private final WebTestClient client = WebTestClient .bindToController(new TestController()) - .exchangeMutator(identitySetup("Pablo")) + .exchangeMutator(principal("Pablo")) .build(); - private UnaryOperator identitySetup(String userName) { + private UnaryOperator principal(String userName) { return exchange -> { Principal user = mock(Principal.class); when(user.getName()).thenReturn(userName); @@ -55,29 +56,52 @@ public class ControllerTests { @Test public void basic() throws Exception { - this.client.get().uri("/test") + this.client.get().uri("/principal") .exchange() .expectStatus().isOk() .expectBody(String.class).value().isEqualTo("Hello Pablo!"); } @Test - public void perRequestIdentityOverride() throws Exception { - this.client.exchangeMutator(identitySetup("Giovani")) - .get().uri("/test") + public void perRequestExchangeMutator() throws Exception { + this.client.exchangeMutator(principal("Giovanni")) + .get().uri("/principal") .exchange() .expectStatus().isOk() - .expectBody(String.class).value().isEqualTo("Hello Giovani!"); + .expectBody(String.class).value().isEqualTo("Hello Giovanni!"); + } + + @Test + public void perRequestMultipleExchangeMutators() throws Exception { + this.client + .exchangeMutator(attribute("attr1", "foo")) + .exchangeMutator(attribute("attr2", "bar")) + .get().uri("/attributes") + .exchange() + .expectStatus().isOk() + .expectBody(String.class).value().isEqualTo("foo+bar"); + } + + private UnaryOperator attribute(String attrName, String attrValue) { + return exchange -> { + exchange.getAttributes().put(attrName, attrValue); + return exchange; + }; } @RestController static class TestController { - @GetMapping("/test") + @GetMapping("/principal") public String handle(Principal principal) { return "Hello " + principal.getName() + "!"; } + + @GetMapping("/attributes") + public String handle(@RequestAttribute String attr1, @RequestAttribute String attr2) { + return attr1 + "+" + attr2; + } } }