diff --git a/spring-modulith-test/src/main/java/org/springframework/modulith/test/ScenarioCustomizer.java b/spring-modulith-test/src/main/java/org/springframework/modulith/test/ScenarioCustomizer.java index abdddac9..e1ccc5cc 100644 --- a/spring-modulith-test/src/main/java/org/springframework/modulith/test/ScenarioCustomizer.java +++ b/spring-modulith-test/src/main/java/org/springframework/modulith/test/ScenarioCustomizer.java @@ -16,13 +16,17 @@ package org.springframework.modulith.test; import java.lang.reflect.Method; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; import java.util.function.Function; +import java.util.function.Supplier; import org.awaitility.core.ConditionFactory; import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.api.extension.InvocationInterceptor; import org.junit.jupiter.api.extension.ReflectiveInvocationContext; import org.springframework.context.ApplicationContext; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.util.Assert; @@ -43,6 +47,30 @@ public interface ScenarioCustomizer extends InvocationInterceptor { */ Function getDefaultCustomizer(Method method, ApplicationContext context); + /** + * Creates a default scenario customizer that will try to find an {@link ExecutorService} in the given + * {@link ApplicationContext} in the following order: + *
    + *
  1. A unique {@link ExecutorService} bean defined
  2. + *
  3. A {@link ThreadPoolTaskExecutor} bean defined (the default Spring Boot creates in case no {@link Executor} is + * explicitly defined in the {@link ApplicationContext}
  4. + *
+ * + * @param context must not be {@literal null}. + * @return will never be {@literal null}. + */ + public static Function forwardExecutorService(ApplicationContext context) { + + Supplier fallback = () -> { + var executor = context.getBeanProvider(ThreadPoolTaskExecutor.class).getIfUnique(); + return executor == null ? null : executor.getThreadPoolExecutor(); + }; + + var executorService = context.getBeanProvider(ExecutorService.class).getIfUnique(fallback); + + return executorService != null ? it -> it.pollExecutorService(executorService) : Function.identity(); + } + /* * (non-Javadoc) * @see org.junit.jupiter.api.extension.InvocationInterceptor#interceptTestTemplateMethod(org.junit.jupiter.api.extension.InvocationInterceptor.Invocation, org.junit.jupiter.api.extension.ReflectiveInvocationContext, org.junit.jupiter.api.extension.ExtensionContext) diff --git a/spring-modulith-test/src/main/java/org/springframework/modulith/test/ScenarioParameterResolver.java b/spring-modulith-test/src/main/java/org/springframework/modulith/test/ScenarioParameterResolver.java index 6816e947..b802db08 100644 --- a/spring-modulith-test/src/main/java/org/springframework/modulith/test/ScenarioParameterResolver.java +++ b/spring-modulith-test/src/main/java/org/springframework/modulith/test/ScenarioParameterResolver.java @@ -77,7 +77,8 @@ class ScenarioParameterResolver implements ParameterResolver, AfterEachCallback var operations = resolveTransactionTemplate(context); var events = (AssertablePublishedEvents) delegate.resolveParameter(parameterContext, extensionContext); - return new Scenario(operations, context, events); + return new Scenario(operations, context, events) + .setDefaultCustomizer(ScenarioCustomizer.forwardExecutorService(context)); } private TransactionTemplate resolveTransactionTemplate(ApplicationContext context) { diff --git a/spring-modulith-test/src/test/java/org/springframework/modulith/test/ScenarioCustomizerIntegrationTests.java b/spring-modulith-test/src/test/java/org/springframework/modulith/test/ScenarioCustomizerIntegrationTests.java index 072693ce..6a07c0bc 100644 --- a/spring-modulith-test/src/test/java/org/springframework/modulith/test/ScenarioCustomizerIntegrationTests.java +++ b/spring-modulith-test/src/test/java/org/springframework/modulith/test/ScenarioCustomizerIntegrationTests.java @@ -19,12 +19,17 @@ import static org.assertj.core.api.Assertions.*; import static org.mockito.Mockito.*; import java.lang.reflect.Method; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.function.Function; +import org.awaitility.Awaitility; import org.awaitility.core.ConditionFactory; +import org.awaitility.core.ExecutorLifecycle; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -50,8 +55,15 @@ class ScenarioCustomizerIntegrationTests { TransactionTemplate transactionTemplate() { return mock(TransactionTemplate.class); } + + @Bean + ExecutorService executorService() { + return Executors.newSingleThreadExecutor(); + } } + @Autowired ExecutorService executorService; + @BeforeEach void setUp() { TestScenarioCustomizer.invoked = false; @@ -61,6 +73,8 @@ class ScenarioCustomizerIntegrationTests { void customizerGetsAppliedForScenarioParameter(Scenario scenario) { assertThat(TestScenarioCustomizer.invoked).isTrue(); + assertThat(TestScenarioCustomizer.SAMPLE).isNotNull(); + assertThat(ReflectionTestUtils.getField(scenario, "defaultCustomizer")) .isSameAs(TestScenarioCustomizer.SAMPLE); } @@ -70,9 +84,23 @@ class ScenarioCustomizerIntegrationTests { assertThat(TestScenarioCustomizer.invoked).isFalse(); } + @Test // GH-165 + @SuppressWarnings("unchecked") + void forwardsExecutorServiceFromApplicationContext(Scenario scenario) { + + var customizer = (Function) ReflectionTestUtils.getField(scenario, + "defaultCustomizer"); + + var factory = customizer.apply(Awaitility.await()); + var lifecycle = (ExecutorLifecycle) ReflectionTestUtils.getField(factory, "executorLifecycle"); + + assertThat(lifecycle).isNotNull(); + assertThat(lifecycle.supplyExecutorService()).isEqualTo(executorService); + } + static class TestScenarioCustomizer implements ScenarioCustomizer { - static Function SAMPLE = it -> it; + static Function SAMPLE; static boolean invoked = false; @Override @@ -81,6 +109,8 @@ class ScenarioCustomizerIntegrationTests { invoked = true; + SAMPLE = ScenarioCustomizer.forwardExecutorService(context); + return SAMPLE; } }