diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java index 665eb281..a27414e8 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java @@ -28,7 +28,9 @@ import java.util.UUID; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.beans.BeansException; import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.beans.factory.ListableBeanFactory; import org.springframework.core.OrderComparator; import org.springframework.core.annotation.AnnotationUtils; @@ -257,7 +259,8 @@ public abstract class AbstractStateMachine extends StateMachineObjectSuppo transitions, triggerToTransitionMap, triggerlessTransitions, initialTransition, initialEvent); if (getBeanFactory() != null) { executor.setBeanFactory(getBeanFactory()); - } else if (getTaskExecutor() != null){ + } + if (getTaskExecutor() != null){ executor.setTaskExecutor(getTaskExecutor()); } executor.afterPropertiesSet(); @@ -280,6 +283,19 @@ public abstract class AbstractStateMachine extends StateMachineObjectSuppo stateMachineExecutor = executor; } + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + // last change to set factory because this maybe be called per + // BeanFactoryAware if machine is created as Bean and configurers + // didn't set it. + if (getBeanFactory() == null) { + super.setBeanFactory(beanFactory); + if (stateMachineExecutor instanceof BeanFactoryAware) { + ((BeanFactoryAware)stateMachineExecutor).setBeanFactory(beanFactory); + } + } + } + @Override protected void doStart() { // if state is set assume nothing to do diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/LifecycleObjectSupport.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/LifecycleObjectSupport.java index 1dde7fca..63b19f5c 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/LifecycleObjectSupport.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/LifecycleObjectSupport.java @@ -15,6 +15,7 @@ */ package org.springframework.statemachine.support; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.ReentrantLock; import org.apache.commons.logging.Log; @@ -55,10 +56,17 @@ public abstract class LifecycleObjectSupport implements InitializingBean, SmartL // to access bean factory private volatile BeanFactory beanFactory; + // protect InitializingBean for single call + private final AtomicBoolean afterPropertiesSetCalled = new AtomicBoolean(false); + @Override public final void afterPropertiesSet() { try { - this.onInit(); + if (afterPropertiesSetCalled.compareAndSet(false, true)) { + this.onInit(); + } else { + log.debug("afterPropertiesSet() is already called, not calling onInit()"); + } } catch (Exception e) { if (e instanceof RuntimeException) { throw (RuntimeException) e; @@ -68,7 +76,7 @@ public abstract class LifecycleObjectSupport implements InitializingBean, SmartL } @Override - public final void setBeanFactory(BeanFactory beanFactory) throws BeansException { + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { Assert.notNull(beanFactory, "beanFactory must not be null"); if(log.isDebugEnabled()) { log.debug("Setting bean factory: " + beanFactory + " for " + this); diff --git a/spring-statemachine-core/src/test/java/org/springframework/statemachine/config/ConfigurationTests.java b/spring-statemachine-core/src/test/java/org/springframework/statemachine/config/ConfigurationTests.java index b18d375b..e9738f68 100644 --- a/spring-statemachine-core/src/test/java/org/springframework/statemachine/config/ConfigurationTests.java +++ b/spring-statemachine-core/src/test/java/org/springframework/statemachine/config/ConfigurationTests.java @@ -17,6 +17,7 @@ package org.springframework.statemachine.config; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.sameInstance; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -28,6 +29,8 @@ import java.util.List; import org.junit.Test; import org.springframework.beans.factory.BeanCreationException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -35,9 +38,11 @@ import org.springframework.core.task.SyncTaskExecutor; import org.springframework.core.task.TaskExecutor; import org.springframework.statemachine.AbstractStateMachineTests; import org.springframework.statemachine.ObjectStateMachine; +import org.springframework.statemachine.StateMachine; import org.springframework.statemachine.StateMachineSystemConstants; import org.springframework.statemachine.TestUtils; import org.springframework.statemachine.action.Action; +import org.springframework.statemachine.config.StateMachineBuilder.Builder; import org.springframework.statemachine.config.builders.StateMachineConfigurationConfigurer; import org.springframework.statemachine.config.builders.StateMachineStateConfigurer; import org.springframework.statemachine.config.builders.StateMachineTransitionConfigurer; @@ -162,6 +167,81 @@ public class ConfigurationTests extends AbstractStateMachineTests { context.refresh(); } + @Test + public void testTaskExecutor1() throws Exception { + // set in builder, no bf or taskExecutor bean registered + context.register(Config14.class); + context.refresh(); + @SuppressWarnings("unchecked") + StateMachine stateMachine = context.getBean(StateMachine.class); + + Object executorFromMachine = TestUtils.readField("taskExecutor", stateMachine); + Object stateMachineExecutor = TestUtils.readField("stateMachineExecutor", stateMachine); + Object executorFromExecutor = TestUtils.readField("taskExecutor", stateMachineExecutor); + + assertThat(executorFromMachine, sameInstance(Config14.taskExecutor)); + assertThat(executorFromExecutor, sameInstance(Config14.taskExecutor)); + + assertThat(executorFromMachine, notNullValue()); + assertThat(executorFromExecutor, notNullValue()); + assertThat(executorFromMachine, sameInstance(executorFromExecutor)); + } + + @Test + public void testTaskExecutor2() throws Exception { + // set as bean, should get from bf + context.register(BaseConfig.class, Config15.class); + context.refresh(); + @SuppressWarnings("unchecked") + StateMachine stateMachine = context.getBean(StateMachine.class); + assertThat(context.containsBean(StateMachineSystemConstants.TASK_EXECUTOR_BEAN_NAME), is(true)); + + Object stateMachineExecutor = TestUtils.readField("stateMachineExecutor", stateMachine); + + Object executorFromMachine = TestUtils.callMethod("getTaskExecutor", stateMachine); + Object executorFromExecutor = TestUtils.callMethod("getTaskExecutor", stateMachineExecutor); + + assertThat(executorFromMachine, notNullValue()); + assertThat(executorFromExecutor, notNullValue()); + assertThat(executorFromMachine, sameInstance(executorFromExecutor)); + } + + @Test + public void testBeanFactory1() throws Exception { + // should come from context + context.register(Config15.class); + context.refresh(); + @SuppressWarnings("unchecked") + StateMachine stateMachine = context.getBean(StateMachine.class); + + Object stateMachineExecutor = TestUtils.readField("stateMachineExecutor", stateMachine); + + Object bfFromMachine = TestUtils.callMethod("getBeanFactory", stateMachine); + Object bfFromExecutor = TestUtils.callMethod("getBeanFactory", stateMachineExecutor); + + assertThat(bfFromMachine, notNullValue()); + assertThat(bfFromExecutor, notNullValue()); + assertThat(bfFromMachine, sameInstance(bfFromExecutor)); + } + + @Test + public void testBeanFactory2() throws Exception { + // set bf in builder + context.register(Config16.class); + context.refresh(); + @SuppressWarnings("unchecked") + StateMachine stateMachine = context.getBean(StateMachine.class); + + Object stateMachineExecutor = TestUtils.readField("stateMachineExecutor", stateMachine); + + Object bfFromMachine = TestUtils.callMethod("getBeanFactory", stateMachine); + Object bfFromExecutor = TestUtils.callMethod("getBeanFactory", stateMachineExecutor); + + assertThat(bfFromMachine, notNullValue()); + assertThat(bfFromExecutor, notNullValue()); + assertThat(bfFromMachine, sameInstance(Config16.beanFactory)); + } + @Configuration @EnableStateMachine public static class Config1 extends EnumStateMachineConfigurerAdapter { @@ -557,4 +637,82 @@ public class ConfigurationTests extends AbstractStateMachineTests { public static class Config13 { } + @Configuration + public static class Config14 { + + public static TaskExecutor taskExecutor = new SyncTaskExecutor(); + + @Bean + StateMachine stateMachine() throws Exception { + Builder builder = StateMachineBuilder.builder(); + builder.configureConfiguration() + .withConfiguration() + .autoStartup(false) + .taskExecutor(taskExecutor); + builder.configureStates() + .withStates() + .initial("S1").state("S2"); + builder.configureTransitions() + .withExternal() + .source("S1").target("S2").event("E1") + .and() + .withExternal() + .source("S2").target("S1").event("E2"); + StateMachine stateMachine = builder.build(); + return stateMachine; + } + + } + + @Configuration + public static class Config15 { + + @Bean + StateMachine stateMachine() throws Exception { + Builder builder = StateMachineBuilder.builder(); + builder.configureConfiguration() + .withConfiguration() + .autoStartup(false); + builder.configureStates() + .withStates() + .initial("S1").state("S2"); + builder.configureTransitions() + .withExternal() + .source("S1").target("S2").event("E1") + .and() + .withExternal() + .source("S2").target("S1").event("E2"); + StateMachine stateMachine = builder.build(); + return stateMachine; + } + + } + + @Configuration + public static class Config16 { + + public static BeanFactory beanFactory = new DefaultListableBeanFactory(); + + @Bean + StateMachine stateMachine() throws Exception { + Builder builder = StateMachineBuilder.builder(); + builder.configureConfiguration() + .withConfiguration() + .autoStartup(false) + .beanFactory(beanFactory); + builder.configureStates() + .withStates() + .initial("S1").state("S2"); + builder.configureTransitions() + .withExternal() + .source("S1").target("S2").event("E1") + .and() + .withExternal() + .source("S2").target("S1").event("E2"); + StateMachine stateMachine = builder.build(); + return stateMachine; + } + + } + } diff --git a/spring-statemachine-core/src/test/java/org/springframework/statemachine/config/ManualBuilderContextTests.java b/spring-statemachine-core/src/test/java/org/springframework/statemachine/config/ManualBuilderContextTests.java new file mode 100644 index 00000000..bbdb2a53 --- /dev/null +++ b/spring-statemachine-core/src/test/java/org/springframework/statemachine/config/ManualBuilderContextTests.java @@ -0,0 +1,162 @@ +/* + * Copyright 2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.statemachine.config; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.junit.Test; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.task.SyncTaskExecutor; +import org.springframework.statemachine.AbstractStateMachineTests; +import org.springframework.statemachine.StateMachine; +import org.springframework.statemachine.config.StateMachineBuilder.Builder; +import org.springframework.statemachine.listener.StateMachineListenerAdapter; +import org.springframework.statemachine.state.State; + +public class ManualBuilderContextTests extends AbstractStateMachineTests { + + @Override + protected AnnotationConfigApplicationContext buildContext() { + return new AnnotationConfigApplicationContext(); + } + + @Test + public void testAsBeanViaBuilder1() throws Exception { + context.register(Config1.class); + context.refresh(); + TestListener listener = context.getBean(TestListener.class); + @SuppressWarnings("unchecked") + StateMachine stateMachine = context.getBean(StateMachine.class); + assertThat(listener.stateMachineStartedLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(stateMachine.getState().getIds(), containsInAnyOrder("S1")); + listener.reset(1); + stateMachine.sendEvent("E1"); + assertThat(listener.stateChangedLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(listener.stateChangedCount, is(1)); + assertThat(stateMachine.getState().getIds(), containsInAnyOrder("S2")); + } + + @Test + public void testAsBeanViaBuilder2() throws Exception { + context.register(Config2.class); + context.refresh(); + TestListener listener = context.getBean(TestListener.class); + @SuppressWarnings("unchecked") + StateMachine stateMachine = context.getBean(StateMachine.class); + stateMachine.start(); + assertThat(listener.stateMachineStartedLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(stateMachine.getState().getIds(), containsInAnyOrder("S1")); + listener.reset(1); + stateMachine.sendEvent("E1"); + assertThat(listener.stateChangedLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(listener.stateChangedCount, is(1)); + assertThat(stateMachine.getState().getIds(), containsInAnyOrder("S2")); + } + + @Configuration + static class Config1 { + + @Bean + StateMachine stateMachine() throws Exception { + Builder builder = StateMachineBuilder.builder(); + builder.configureConfiguration() + .withConfiguration() + .autoStartup(true) + .listener(testListener()) + .taskExecutor(new SyncTaskExecutor()); + builder.configureStates() + .withStates() + .initial("S1").state("S2"); + builder.configureTransitions() + .withExternal() + .source("S1").target("S2").event("E1") + .and() + .withExternal() + .source("S2").target("S1").event("E2"); + StateMachine stateMachine = builder.build(); + return stateMachine; + } + + @Bean + TestListener testListener() { + return new TestListener(); + } + + } + + @Configuration + static class Config2 { + + @Bean + StateMachine stateMachine() throws Exception { + Builder builder = StateMachineBuilder.builder(); + builder.configureConfiguration() + .withConfiguration() + .autoStartup(false) + .listener(testListener()) + .taskExecutor(new SyncTaskExecutor()); + builder.configureStates() + .withStates() + .initial("S1").state("S2"); + builder.configureTransitions() + .withExternal() + .source("S1").target("S2").event("E1") + .and() + .withExternal() + .source("S2").target("S1").event("E2"); + StateMachine stateMachine = builder.build(); + return stateMachine; + } + + @Bean + TestListener testListener() { + return new TestListener(); + } + + } + + private static class TestListener extends StateMachineListenerAdapter { + + volatile CountDownLatch stateMachineStartedLatch = new CountDownLatch(1); + volatile CountDownLatch stateChangedLatch = new CountDownLatch(1); + volatile int stateChangedCount = 0; + + @Override + public void stateMachineStarted(StateMachine stateMachine) { + stateMachineStartedLatch.countDown(); + } + + @Override + public void stateChanged(State from, State to) { + stateChangedCount++; + stateChangedLatch.countDown(); + } + + public void reset(int a1) { + stateChangedCount = 0; + stateChangedLatch = new CountDownLatch(a1); + } + + } + +} diff --git a/spring-statemachine-core/src/test/java/org/springframework/statemachine/event/ContextEventTests.java b/spring-statemachine-core/src/test/java/org/springframework/statemachine/event/ContextEventTests.java index f93b7355..3a721eb7 100644 --- a/spring-statemachine-core/src/test/java/org/springframework/statemachine/event/ContextEventTests.java +++ b/spring-statemachine-core/src/test/java/org/springframework/statemachine/event/ContextEventTests.java @@ -100,9 +100,11 @@ public class ContextEventTests extends AbstractStateMachineTests { context.refresh(); ObjectStateMachine machine = context.getBean(StateMachineSystemConstants.DEFAULT_ID_STATEMACHINE, ObjectStateMachine.class); - machine.start(); - machine.sendEvent(TestEvents.E1); StateMachineApplicationEventListener listener = context.getBean(StateMachineApplicationEventListener.class); + machine.start(); + listener.latch.await(1, TimeUnit.SECONDS); + listener.reset(); + machine.sendEvent(TestEvents.E1); listener.latch.await(1, TimeUnit.SECONDS); assertThat(listener.count, greaterThan(1)); } @@ -188,14 +190,19 @@ public class ContextEventTests extends AbstractStateMachineTests { static class StateMachineApplicationEventListener implements ApplicationListener { - CountDownLatch latch = new CountDownLatch(1); - int count = 0; + volatile CountDownLatch latch = new CountDownLatch(1); + volatile int count = 0; @Override public void onApplicationEvent(StateMachineEvent event) { count++; latch.countDown(); } + + public void reset() { + count = 0; + latch = new CountDownLatch(1); + } } }