diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java index 7e4b15a6..40552cf7 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java @@ -255,8 +255,8 @@ public abstract class AbstractStateMachine extends StateMachineObjectSuppo } } - DefaultStateMachineExecutor executor = new DefaultStateMachineExecutor(this, getRelayStateMachine(), extendedState, - transitions, triggerToTransitionMap, triggerlessTransitions, initialTransition, initialEvent); + DefaultStateMachineExecutor executor = new DefaultStateMachineExecutor(this, getRelayStateMachine(), transitions, + triggerToTransitionMap, triggerlessTransitions, initialTransition, initialEvent); if (getBeanFactory() != null) { executor.setBeanFactory(getBeanFactory()); } diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/DefaultStateMachineExecutor.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/DefaultStateMachineExecutor.java index ba6c4d91..2e2bd220 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/DefaultStateMachineExecutor.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/DefaultStateMachineExecutor.java @@ -35,7 +35,6 @@ import org.springframework.context.Lifecycle; import org.springframework.core.task.TaskExecutor; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; -import org.springframework.statemachine.ExtendedState; import org.springframework.statemachine.StateContext; import org.springframework.statemachine.StateMachine; import org.springframework.statemachine.StateMachineSystemConstants; @@ -62,8 +61,6 @@ public class DefaultStateMachineExecutor extends LifecycleObjectSupport im private final StateMachine relayStateMachine; - private final ExtendedState extendedState; - private final Queue> eventQueue = new ConcurrentLinkedQueue>(); private final LinkedList> deferList = new LinkedList>(); @@ -100,7 +97,6 @@ public class DefaultStateMachineExecutor extends LifecycleObjectSupport im * * @param stateMachine the state machine * @param relayStateMachine the relay state machine - * @param extendedState the extended state * @param transitions the transitions * @param triggerToTransitionMap the trigger to transition map * @param triggerlessTransitions the triggerless transitions @@ -108,11 +104,10 @@ public class DefaultStateMachineExecutor extends LifecycleObjectSupport im * @param initialEvent the initial event */ public DefaultStateMachineExecutor(StateMachine stateMachine, StateMachine relayStateMachine, - ExtendedState extendedState, Collection> transitions, Map, Transition> triggerToTransitionMap, + Collection> transitions, Map, Transition> triggerToTransitionMap, List> triggerlessTransitions, Transition initialTransition, Message initialEvent) { this.stateMachine = stateMachine; this.relayStateMachine = relayStateMachine; - this.extendedState = extendedState; this.triggerToTransitionMap = triggerToTransitionMap; this.triggerlessTransitions = triggerlessTransitions; this.transitions = transitions; @@ -400,7 +395,7 @@ public class DefaultStateMachineExecutor extends LifecycleObjectSupport im // we want to keep the originating sm id map.put(StateMachineSystemConstants.STATEMACHINE_IDENTIFIER, stateMachine.getId()); } - return new DefaultStateContext(event, new MessageHeaders(map), extendedState, transition, stateMachine); + return new DefaultStateContext(event, new MessageHeaders(map), stateMachine.getExtendedState(), transition, stateMachine); } private void registerTriggerListener() { diff --git a/spring-statemachine-core/src/test/java/org/springframework/statemachine/StateMachineResetTests.java b/spring-statemachine-core/src/test/java/org/springframework/statemachine/StateMachineResetTests.java index cabd47e8..553b4268 100644 --- a/spring-statemachine-core/src/test/java/org/springframework/statemachine/StateMachineResetTests.java +++ b/spring-statemachine-core/src/test/java/org/springframework/statemachine/StateMachineResetTests.java @@ -17,6 +17,7 @@ package org.springframework.statemachine; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertThat; import java.util.ArrayList; @@ -33,6 +34,7 @@ import org.springframework.statemachine.access.StateMachineFunction; import org.springframework.statemachine.action.Action; import org.springframework.statemachine.config.EnableStateMachine; import org.springframework.statemachine.config.EnumStateMachineConfigurerAdapter; +import org.springframework.statemachine.config.builders.StateMachineConfigurationConfigurer; import org.springframework.statemachine.config.builders.StateMachineStateConfigurer; import org.springframework.statemachine.config.builders.StateMachineTransitionConfigurer; import org.springframework.statemachine.guard.Guard; @@ -190,6 +192,37 @@ public class StateMachineResetTests extends AbstractStateMachineTests { assertThat(machine.getState().getIds(), containsInAnyOrder(TestStates.S2, TestStates.S21, TestStates.S31)); } + @Test + public void testResetUpdateExtendedStateVariables() { + context.register(Config3.class); + context.refresh(); + @SuppressWarnings("unchecked") + StateMachine machine = context.getBean(StateMachineSystemConstants.DEFAULT_ID_STATEMACHINE, StateMachine.class); + + assertThat((Integer)machine.getExtendedState().getVariables().get("count"), nullValue()); + machine.sendEvent(Events.A); + assertThat((Integer)machine.getExtendedState().getVariables().get("count"), is(1)); + + machine.stop(); + Map variables = new HashMap(); + variables.putAll(machine.getExtendedState().getVariables()); + ExtendedState extendedState = new DefaultExtendedState(variables); + DefaultStateMachineContext stateMachineContext = new DefaultStateMachineContext(States.S0, null, null, extendedState); + + machine.getStateMachineAccessor().doWithAllRegions(new StateMachineFunction>() { + + @Override + public void apply(StateMachineAccess function) { + function.resetStateMachine(stateMachineContext); + } + }); + + machine.start(); + assertThat((Integer)machine.getExtendedState().getVariables().get("count"), is(1)); + machine.sendEvent(Events.A); + assertThat((Integer)machine.getExtendedState().getVariables().get("count"), is(2)); + } + @Configuration @EnableStateMachine static class Config1 extends EnumStateMachineConfigurerAdapter { @@ -311,12 +344,60 @@ public class StateMachineResetTests extends AbstractStateMachineTests { } + @Configuration + @EnableStateMachine + static class Config3 extends EnumStateMachineConfigurerAdapter { + + @Override + public void configure(StateMachineConfigurationConfigurer config) + throws Exception { + config + .withConfiguration() + .autoStartup(true); + } + + @Override + public void configure(StateMachineStateConfigurer states) + throws Exception { + states + .withStates() + .initial(States.S0); + } + + @Override + public void configure(StateMachineTransitionConfigurer transitions) + throws Exception { + transitions + .withInternal() + .source(States.S0) + .event(Events.A) + .action(updateAction()); + } + + @Bean + public Action updateAction() { + return new Action() { + + @Override + public void execute(StateContext context) { + Integer count = context.getExtendedState().get("count", Integer.class); + if (count == null) { + context.getExtendedState().getVariables().put("count", 1); + } else { + context.getExtendedState().getVariables().put("count", (count + 1)); + } + } + }; + } + + } + public static enum States { - S0, S1, S11, S12, S2, S21, S211, S212 + S0, S1, S11, S12, S2, S21, S211, S212 } public static enum Events { - A, B, C, D, E, F, G, H, I + A, B, C, D, E, F, G, H, I } private static class FooAction implements Action {