From 6ae2f2fa438e438e2a038fc01a3db908bbe28731 Mon Sep 17 00:00:00 2001 From: Janne Valkealahti Date: Sat, 4 Jul 2015 15:15:27 +0100 Subject: [PATCH] Can now skip initial state - Fixes #71 - If substate is entered directry, we don't go via initial state for particular state machine where transition ends. - Fixing other tests which assumed initial state is entered. - In terms of a history state, initial state is still entered. --- .../access/StateMachineAccess.java | 8 + .../access/StateMachineAccessor.java | 14 ++ .../statemachine/state/StateMachineState.java | 48 +++++ .../support/AbstractStateMachine.java | 33 +++- .../support/StateMachineUtils.java | 3 + .../AbstractStateMachineTests.java | 3 +- .../statemachine/SubStateMachineTests.java | 16 +- .../access/StateMachineAccessTests.java | 13 ++ .../statemachine/state/HistoryStateTests.java | 1 - .../transition/TransitionTests.java | 182 +++++++++++++++++- .../java/demo/showcase/ShowcaseTests.java | 14 +- 11 files changed, 313 insertions(+), 22 deletions(-) diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/access/StateMachineAccess.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/access/StateMachineAccess.java index 7528049d..a24f3cd1 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/access/StateMachineAccess.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/access/StateMachineAccess.java @@ -57,4 +57,12 @@ public interface StateMachineAccess { */ void addStateChangeInterceptor(StateChangeInterceptor interceptor); + /** + * Sets if initial state is enabled when a state machine is + * using sub states. + * + * @param enabled the new initial enabled + */ + void setInitialEnabled(boolean enabled); + } diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/access/StateMachineAccessor.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/access/StateMachineAccessor.java index 25e19b42..08ed61b8 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/access/StateMachineAccessor.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/access/StateMachineAccessor.java @@ -46,4 +46,18 @@ public interface StateMachineAccessor { */ List> withAllRegions(); + /** + * Execute given {@link StateMachineFunction} with a region. + * + * @param stateMachineAccess the state machine access + */ + void doWithRegion(StateMachineFunction> stateMachineAccess); + + /** + * Get a region. + * + * @return the state machine access + */ + StateMachineAccess withRegion(); + } diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/state/StateMachineState.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/state/StateMachineState.java index b64364d7..d06e0517 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/state/StateMachineState.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/state/StateMachineState.java @@ -21,6 +21,8 @@ import java.util.Collection; import org.springframework.messaging.Message; import org.springframework.statemachine.StateContext; import org.springframework.statemachine.StateMachine; +import org.springframework.statemachine.access.StateMachineAccess; +import org.springframework.statemachine.access.StateMachineFunction; import org.springframework.statemachine.action.Action; import org.springframework.statemachine.support.StateMachineUtils; import org.springframework.statemachine.transition.Transition; @@ -161,10 +163,56 @@ public class StateMachineState extends AbstractState { } if (getPseudoState() != null && getPseudoState().getKind() == PseudoStateKind.INITIAL) { + // disable initial state if it looks like we're about + // to transit directory into a non initial state + // we do transit via initial state if we're returning + // via history state + boolean initialEnabled = true; + if (context.getTransition() != null) { + State target = context.getTransition().getTarget(); + PseudoStateKind kind = target.getPseudoState() != null ? target.getPseudoState().getKind() : null; + State findDeepParent = findDeepParent(getSubmachine().getStates(), target); + if (findDeepParent != null && findDeepParent.isSubmachineState()) { + ((StateMachineState) findDeepParent).getSubmachine().getStateMachineAccessor() + .doWithRegion(new StateMachineFunction>() { + + @Override + public void apply(StateMachineAccess function) { + function.setInitialEnabled(false); + } + }); + } + if (getSubmachine().getStates().contains(target) && kind != PseudoStateKind.HISTORY_SHALLOW + && kind != PseudoStateKind.HISTORY_DEEP) { + initialEnabled = false; + } + + } + // need final for state machine access + final boolean enabled = initialEnabled; + getSubmachine().getStateMachineAccessor().doWithRegion( + new StateMachineFunction>() { + + @Override + public void apply(StateMachineAccess function) { + function.setInitialEnabled(enabled); + } + }); getSubmachine().start(); } } + private State findDeepParent(Collection> states, State state) { + for (State s : states) { + if (s.getStates().contains(state)) { + if (s != state) { + return s; + } + } + } + return null; + } + @Override public boolean sendEvent(Message event) { StateMachine machine = getSubmachine(); diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java index b5e3a764..93a7e436 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java @@ -104,6 +104,8 @@ public abstract class AbstractStateMachine extends StateMachineObjectSuppo private StateMachineExecutor stateMachineExecutor; + private Boolean initialEnabled = null; + /** * Instantiates a new abstract state machine. * @@ -243,7 +245,7 @@ public abstract class AbstractStateMachine extends StateMachineObjectSuppo notifyTransitionStart(t); callHandlers(t.getSource(), t.getTarget(), queuedMessage); if (t.getKind() == TransitionKind.INITIAL) { - switchToState(t.getTarget(), queuedMessage, null, getRelayStateMachine()); + switchToState(t.getTarget(), queuedMessage, t, getRelayStateMachine()); notifyStateMachineStarted(getRelayStateMachine()); } else if (t.getKind() != TransitionKind.INTERNAL) { switchToState(t.getTarget(), queuedMessage, t, getRelayStateMachine()); @@ -265,6 +267,10 @@ public abstract class AbstractStateMachine extends StateMachineObjectSuppo } registerPseudoStateListener(); + if (initialEnabled != null && !initialEnabled) { + stateMachineExecutor.setInitialEnabled(false); + } + // start fires first execution which should execute initial transition stateMachineExecutor.start(); } @@ -274,6 +280,7 @@ public abstract class AbstractStateMachine extends StateMachineObjectSuppo stateMachineExecutor.stop(); notifyStateMachineStopped(this); currentState = null; + initialEnabled = null; } @Override @@ -312,6 +319,14 @@ public abstract class AbstractStateMachine extends StateMachineObjectSuppo return transitions; } + + @Override + public void setInitialEnabled(boolean enabled) { + if (initialEnabled == null) { + initialEnabled = enabled; + } + } + @SuppressWarnings("unchecked") @Override public StateMachineAccessor getStateMachineAccessor() { @@ -353,6 +368,16 @@ public abstract class AbstractStateMachine extends StateMachineObjectSuppo } return list; } + + @Override + public void doWithRegion(StateMachineFunction> stateMachineAccess) { + stateMachineAccess.apply(AbstractStateMachine.this); + } + + @Override + public StateMachineAccess withRegion() { + return AbstractStateMachine.this; + } }; } @@ -538,7 +563,6 @@ public abstract class AbstractStateMachine extends StateMachineObjectSuppo } entryToState(state, message, transition, stateMachine); notifyStateChanged(notifyFrom, state); - StateContext stateContext = buildStateContext(message, transition, stateMachine); } else if (currentState != null) { if (findDeep != null) { if (exit) { @@ -575,8 +599,11 @@ public abstract class AbstractStateMachine extends StateMachineObjectSuppo } } } + boolean shouldEntry = findDeep != currentState; currentState = findDeep; - entryToState(currentState, message, transition, stateMachine); + if (shouldEntry) { + entryToState(currentState, message, transition, stateMachine); + } if (currentState.isSubmachineState()) { StateMachine submachine = ((AbstractState)currentState).getSubmachine(); diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/StateMachineUtils.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/StateMachineUtils.java index c73444b9..b3061cdb 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/StateMachineUtils.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/StateMachineUtils.java @@ -37,6 +37,9 @@ public abstract class StateMachineUtils { * @return if sub is child of super */ public static boolean isSubstate(State left, State right) { + if (left == null) { + return false; + } Collection> c = left.getStates(); c.remove(left); return c.contains(right); diff --git a/spring-statemachine-core/src/test/java/org/springframework/statemachine/AbstractStateMachineTests.java b/spring-statemachine-core/src/test/java/org/springframework/statemachine/AbstractStateMachineTests.java index 7a6c0459..cb00cdcf 100644 --- a/spring-statemachine-core/src/test/java/org/springframework/statemachine/AbstractStateMachineTests.java +++ b/spring-statemachine-core/src/test/java/org/springframework/statemachine/AbstractStateMachineTests.java @@ -78,7 +78,8 @@ public abstract class AbstractStateMachineTests { public static enum TestStates2 { BUSY, PLAYING, PAUSED, - IDLE, CLOSED, OPEN + IDLE, CLOSED, OPEN, + PAUSED1, PAUSED2 } public static enum TestStates3 { diff --git a/spring-statemachine-core/src/test/java/org/springframework/statemachine/SubStateMachineTests.java b/spring-statemachine-core/src/test/java/org/springframework/statemachine/SubStateMachineTests.java index c769886e..c2eb6dc9 100644 --- a/spring-statemachine-core/src/test/java/org/springframework/statemachine/SubStateMachineTests.java +++ b/spring-statemachine-core/src/test/java/org/springframework/statemachine/SubStateMachineTests.java @@ -146,11 +146,11 @@ public class SubStateMachineTests extends AbstractStateMachineTests { assertThat(entryActionS1.onExecuteLatch.await(1, TimeUnit.SECONDS), is(true)); assertThat(exitActionS1.onExecuteLatch.await(1, TimeUnit.SECONDS), is(true)); - assertThat(entryActionS11.stateContexts.size(), is(2)); + assertThat(entryActionS11.stateContexts.size(), is(1)); assertThat(exitActionS11.stateContexts.size(), is(1)); - assertThat(entryActionS11.stateContexts.size(), is(2)); + assertThat(entryActionS11.stateContexts.size(), is(1)); assertThat(exitActionS11.stateContexts.size(), is(1)); - assertThat(entryActionS1.stateContexts.size(), is(2)); + assertThat(entryActionS1.stateContexts.size(), is(1)); assertThat(exitActionS1.stateContexts.size(), is(1)); } @@ -336,9 +336,9 @@ public class SubStateMachineTests extends AbstractStateMachineTests { assertThat(entryActionS1.onExecuteLatch.await(1, TimeUnit.SECONDS), is(true)); assertThat(exitActionS1.onExecuteLatch.await(1, TimeUnit.SECONDS), is(false)); - assertThat(entryActionS11.stateContexts.size(), is(2)); + assertThat(entryActionS11.stateContexts.size(), is(1)); assertThat(exitActionS11.stateContexts.size(), is(1)); - assertThat(entryActionS11.stateContexts.size(), is(2)); + assertThat(entryActionS11.stateContexts.size(), is(1)); assertThat(exitActionS11.stateContexts.size(), is(1)); assertThat(entryActionS1.stateContexts.size(), is(1)); assertThat(exitActionS1.stateContexts.size(), is(0)); @@ -371,11 +371,11 @@ public class SubStateMachineTests extends AbstractStateMachineTests { assertThat(entryActionS1.onExecuteLatch.await(1, TimeUnit.SECONDS), is(true)); assertThat(exitActionS1.onExecuteLatch.await(1, TimeUnit.SECONDS), is(true)); - assertThat(entryActionS11.stateContexts.size(), is(2)); + assertThat(entryActionS11.stateContexts.size(), is(1)); assertThat(exitActionS11.stateContexts.size(), is(1)); - assertThat(entryActionS11.stateContexts.size(), is(2)); + assertThat(entryActionS11.stateContexts.size(), is(1)); assertThat(exitActionS11.stateContexts.size(), is(1)); - assertThat(entryActionS1.stateContexts.size(), is(2)); + assertThat(entryActionS1.stateContexts.size(), is(1)); assertThat(exitActionS1.stateContexts.size(), is(1)); } diff --git a/spring-statemachine-core/src/test/java/org/springframework/statemachine/access/StateMachineAccessTests.java b/spring-statemachine-core/src/test/java/org/springframework/statemachine/access/StateMachineAccessTests.java index db6d7300..081fa4bf 100644 --- a/spring-statemachine-core/src/test/java/org/springframework/statemachine/access/StateMachineAccessTests.java +++ b/spring-statemachine-core/src/test/java/org/springframework/statemachine/access/StateMachineAccessTests.java @@ -78,6 +78,15 @@ public class StateMachineAccessTests { list.add(MockStateMachine.this); return list; } + + @Override + public void doWithRegion(StateMachineFunction> stateMachineAccess) { + } + + @Override + public StateMachineAccess withRegion() { + return null; + } }; } @@ -154,6 +163,10 @@ public class StateMachineAccessTests { return null; } + @Override + public void setInitialEnabled(boolean enabled) { + } + } } diff --git a/spring-statemachine-core/src/test/java/org/springframework/statemachine/state/HistoryStateTests.java b/spring-statemachine-core/src/test/java/org/springframework/statemachine/state/HistoryStateTests.java index f713c8cb..c2e9b194 100644 --- a/spring-statemachine-core/src/test/java/org/springframework/statemachine/state/HistoryStateTests.java +++ b/spring-statemachine-core/src/test/java/org/springframework/statemachine/state/HistoryStateTests.java @@ -68,7 +68,6 @@ public class HistoryStateTests extends AbstractStateMachineTests { machine.sendEvent(TestEvents.E2); machine.sendEvent(TestEvents.E3); machine.sendEvent(TestEvents.E4); - assertThat(machine.getState().getIds(), contains(TestStates.S2, TestStates.S21, TestStates.S212)); } diff --git a/spring-statemachine-core/src/test/java/org/springframework/statemachine/transition/TransitionTests.java b/spring-statemachine-core/src/test/java/org/springframework/statemachine/transition/TransitionTests.java index 9b172f2e..82ce7b54 100644 --- a/spring-statemachine-core/src/test/java/org/springframework/statemachine/transition/TransitionTests.java +++ b/spring-statemachine-core/src/test/java/org/springframework/statemachine/transition/TransitionTests.java @@ -180,6 +180,54 @@ public class TransitionTests extends AbstractStateMachineTests { assertThat(machine.getState().getIds(), contains(TestStates.S2)); } + @Test + public void testTransitDirectlyToSubstateSkipInitial() throws InterruptedException { + context.register(BaseConfig.class, Config7.class); + context.refresh(); + assertTrue(context.containsBean(StateMachineSystemConstants.DEFAULT_ID_STATEMACHINE)); + @SuppressWarnings("unchecked") + ObjectStateMachine machine = + context.getBean(StateMachineSystemConstants.DEFAULT_ID_STATEMACHINE, ObjectStateMachine.class); + TestListener2 listener = new TestListener2(); + machine.addStateListener(listener); + listener.reset(2); + + machine.start(); + assertThat(listener.stateChangedLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(listener.stateChangedCount, is(2)); + assertThat(machine.getState().getIds(), contains(TestStates2.IDLE, TestStates2.CLOSED)); + + listener.reset(0, 2); + machine.sendEvent(TestEvents2.PAUSE); + assertThat(listener.stateEnteredLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(listener.stateEnteredCount, is(2)); + assertThat(machine.getState().getIds(), contains(TestStates2.BUSY, TestStates2.PAUSED)); + } + + @Test + public void testTransitDeepDirectlyToSubstateSkipInitial() throws InterruptedException { + context.register(BaseConfig.class, Config8.class); + context.refresh(); + assertTrue(context.containsBean(StateMachineSystemConstants.DEFAULT_ID_STATEMACHINE)); + @SuppressWarnings("unchecked") + ObjectStateMachine machine = + context.getBean(StateMachineSystemConstants.DEFAULT_ID_STATEMACHINE, ObjectStateMachine.class); + TestListener2 listener = new TestListener2(); + machine.addStateListener(listener); + listener.reset(2); + + machine.start(); + assertThat(listener.stateChangedLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(listener.stateChangedCount, is(2)); + assertThat(machine.getState().getIds(), contains(TestStates2.IDLE, TestStates2.CLOSED)); + + listener.reset(0, 4); + machine.sendEvent(TestEvents2.PAUSE); + assertThat(listener.stateEnteredLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(listener.stateEnteredCount, is(4)); + assertThat(machine.getState().getIds(), contains(TestStates2.BUSY, TestStates2.PAUSED, TestStates2.PAUSED2)); + } + @Configuration @EnableStateMachine public static class Config1 extends EnumStateMachineConfigurerAdapter { @@ -408,6 +456,106 @@ public class TransitionTests extends AbstractStateMachineTests { } + @Configuration + @EnableStateMachine + static class Config7 extends EnumStateMachineConfigurerAdapter { + + @Override + public void configure(StateMachineStateConfigurer states) throws Exception { + states + .withStates() + .initial(TestStates2.IDLE) + .state(TestStates2.IDLE) + .and() + .withStates() + .parent(TestStates2.IDLE) + .initial(TestStates2.CLOSED) + .state(TestStates2.CLOSED) + .state(TestStates2.OPEN) + .and() + .withStates() + .state(TestStates2.BUSY) + .and() + .withStates() + .parent(TestStates2.BUSY) + .initial(TestStates2.PLAYING) + .state(TestStates2.PLAYING) + .state(TestStates2.PAUSED); + + } + + @Override + public void configure(StateMachineTransitionConfigurer transitions) throws Exception { + transitions + .withExternal() + .source(TestStates2.CLOSED) + .target(TestStates2.OPEN) + .event(TestEvents2.EJECT) + .and() + .withExternal() + .source(TestStates2.OPEN) + .target(TestStates2.CLOSED) + .event(TestEvents2.EJECT) + .and() + .withExternal() + .source(TestStates2.CLOSED) + .target(TestStates2.PAUSED) + .event(TestEvents2.PAUSE); + } + + } + + @Configuration + @EnableStateMachine + static class Config8 extends EnumStateMachineConfigurerAdapter { + + @Override + public void configure(StateMachineStateConfigurer states) throws Exception { + states + .withStates() + .initial(TestStates2.IDLE) + .state(TestStates2.IDLE) + .and() + .withStates() + .parent(TestStates2.IDLE) + .initial(TestStates2.CLOSED) + .state(TestStates2.OPEN) + .and() + .withStates() + .state(TestStates2.BUSY) + .and() + .withStates() + .parent(TestStates2.BUSY) + .initial(TestStates2.PLAYING) + .state(TestStates2.PAUSED) + .and() + .withStates() + .parent(TestStates2.PAUSED) + .initial(TestStates2.PAUSED1) + .state(TestStates2.PAUSED2); + } + + @Override + public void configure(StateMachineTransitionConfigurer transitions) throws Exception { + transitions + .withExternal() + .source(TestStates2.CLOSED) + .target(TestStates2.OPEN) + .event(TestEvents2.EJECT) + .and() + .withExternal() + .source(TestStates2.OPEN) + .target(TestStates2.CLOSED) + .event(TestEvents2.EJECT) + .and() + .withExternal() + .source(TestStates2.CLOSED) + .target(TestStates2.PAUSED2) + .event(TestEvents2.PAUSE); + } + + } + static class TestListener extends StateMachineListenerAdapter { volatile CountDownLatch stateChangedLatch = new CountDownLatch(1); @@ -415,8 +563,8 @@ public class TransitionTests extends AbstractStateMachineTests { @Override public void stateChanged(State from, State to) { - stateChangedLatch.countDown(); stateChangedCount++; + stateChangedLatch.countDown(); } public void reset(int c1) { @@ -426,4 +574,36 @@ public class TransitionTests extends AbstractStateMachineTests { } + static class TestListener2 extends StateMachineListenerAdapter { + + volatile CountDownLatch stateChangedLatch = new CountDownLatch(1); + volatile int stateChangedCount = 0; + volatile CountDownLatch stateEnteredLatch = new CountDownLatch(1); + volatile int stateEnteredCount = 0; + + @Override + public void stateChanged(State from, State to) { + stateChangedCount++; + stateChangedLatch.countDown(); + } + + @Override + public void stateEntered(State state) { + stateEnteredCount++; + stateEnteredLatch.countDown(); + } + + public void reset(int c1) { + reset(c1, 0); + } + + public void reset(int c1, int c2) { + stateChangedLatch = new CountDownLatch(c1); + stateChangedCount = 0; + stateEnteredLatch = new CountDownLatch(c2); + stateEnteredCount = 0; + } + + } + } diff --git a/spring-statemachine-samples/showcase/src/test/java/demo/showcase/ShowcaseTests.java b/spring-statemachine-samples/showcase/src/test/java/demo/showcase/ShowcaseTests.java index 8c020106..c24c1833 100644 --- a/spring-statemachine-samples/showcase/src/test/java/demo/showcase/ShowcaseTests.java +++ b/spring-statemachine-samples/showcase/src/test/java/demo/showcase/ShowcaseTests.java @@ -170,12 +170,10 @@ public class ShowcaseTests { public void testII() throws Exception { machine.sendEvent(Events.I); - listener.reset(1, 0, 0); - // TODO: should think if need to bypass - // S211 as initial state and go directly - // to S212. + listener.reset(2, 0, 0); machine.sendEvent(Events.I); - listener.stateChangedLatch.await(1, TimeUnit.SECONDS); + assertThat(listener.stateChangedLatch.await(1, TimeUnit.SECONDS), is(true)); + assertThat(listener.statesEntered.size(), is(3)); assertThat(machine.getState().getIds(), contains(States.S0, States.S2, States.S21, States.S212)); } @@ -216,7 +214,7 @@ public class ShowcaseTests { listener.stateChangedLatch.await(1, TimeUnit.SECONDS); assertThat(machine.getState().getIds(), contains(States.S0, States.S2, States.S21, States.S211)); assertThat(listener.statesExited.size(), is(2)); - assertThat(listener.statesEntered.size(), is(3)); + assertThat(listener.statesEntered.size(), is(2)); } @Test @@ -226,7 +224,7 @@ public class ShowcaseTests { listener.stateChangedLatch.await(1, TimeUnit.SECONDS); assertThat(machine.getState().getIds(), contains(States.S0, States.S2, States.S21, States.S211)); assertThat(listener.statesExited.size(), is(2)); - assertThat(listener.statesEntered.size(), is(4)); + assertThat(listener.statesEntered.size(), is(3)); } @Test @@ -236,7 +234,7 @@ public class ShowcaseTests { listener.stateChangedLatch.await(1, TimeUnit.SECONDS); assertThat(machine.getState().getIds(), contains(States.S0, States.S2, States.S21, States.S211)); assertThat(listener.statesExited.size(), is(2)); - assertThat(listener.statesEntered.size(), is(4)); + assertThat(listener.statesEntered.size(), is(3)); } static class Config {