From 280e77347ed84060b82393651d74e226c42f98a5 Mon Sep 17 00:00:00 2001 From: Janne Valkealahti Date: Fri, 9 Oct 2015 17:35:04 +0100 Subject: [PATCH] Add better support for region persist - Various changes to have a better support for persisting regions. - Fixes for TasksHandler. - Fixes #94 --- .../support/AbstractStateMachine.java | 25 +++ .../support/DefaultStateMachineContext.java | 4 +- .../support/StateMachineInterceptorList.java | 5 + .../statemachine/StateMachineResetTests.java | 33 +++- .../recipes/tasks/TasksHandler.java | 116 +++++++++++- .../recipes/TasksHandlerTests.java | 178 ++++++++++++++++++ .../recipes/TestStateMachinePersist.java | 62 ++++++ 7 files changed, 416 insertions(+), 7 deletions(-) create mode 100644 spring-statemachine-recipes/src/test/java/org/springframework/statemachine/recipes/TestStateMachinePersist.java diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java index a077c0db..a992b644 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java @@ -287,6 +287,10 @@ public abstract class AbstractStateMachine extends StateMachineObjectSuppo } stateMachineExecutor.setInitialEnabled(false); stateMachineExecutor.start(); + // assume that state was set/reseted so we need to + // dispatch started event which would net getting + // dispatched via executor + notifyStateMachineStarted(getRelayStateMachine()); return; } registerPseudoStateListener(); @@ -467,10 +471,16 @@ public abstract class AbstractStateMachine extends StateMachineObjectSuppo @Override public void resetStateMachine(StateMachineContext stateMachineContext) { + if (stateMachineContext == null) { + return; + } if (log.isDebugEnabled()) { log.debug("Request to reset state machine: stateMachine=[" + this + "] stateMachineContext=[" + stateMachineContext + "]"); } S state = stateMachineContext.getState(); + if (state == null) { + return; + } boolean stateSet = false; for (State s : getStates()) { for (State ss : s.getStates()) { @@ -480,6 +490,15 @@ public abstract class AbstractStateMachine extends StateMachineObjectSuppo // needed if we only transit to super state or reset regions if (s.isSubmachineState()) { StateMachine submachine = ((AbstractState)s).getSubmachine(); + for (final StateMachineContext child : stateMachineContext.getChilds()) { + submachine.getStateMachineAccessor().doWithRegion(new StateMachineFunction>() { + + @Override + public void apply(StateMachineAccess function) { + function.resetStateMachine(child); + } + }); + } submachine.start(); } else if (s.isOrthogonal() && stateMachineContext.getChilds() != null) { Collection> regions = ((AbstractState)s).getRegions(); @@ -591,10 +610,16 @@ public abstract class AbstractStateMachine extends StateMachineObjectSuppo || kind == PseudoStateKind.HISTORY_DEEP) { StateContext stateContext = buildStateContext(message, transition, stateMachine); State toState = state.getPseudoState().entry(stateContext); + + if (kind == PseudoStateKind.CHOICE) { + callPreStateChangeInterceptors(toState, message, transition, stateMachine); + } + setCurrentState(toState, message, transition, true, stateMachine); } else if (kind == PseudoStateKind.FORK) { ForkPseudoState fps = (ForkPseudoState) state.getPseudoState(); for (State ss : fps.getForks()) { + callPreStateChangeInterceptors(ss, message, transition, stateMachine); setCurrentState(ss, message, transition, false, stateMachine); } } else { diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/DefaultStateMachineContext.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/DefaultStateMachineContext.java index c01bbf1a..b6993c33 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/DefaultStateMachineContext.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/DefaultStateMachineContext.java @@ -95,8 +95,8 @@ public class DefaultStateMachineContext implements StateMachineContext { return exception; } + @Override + public String toString() { + return "StateMachineInterceptorList [interceptors=" + interceptors + "]"; + } + } diff --git a/spring-statemachine-core/src/test/java/org/springframework/statemachine/StateMachineResetTests.java b/spring-statemachine-core/src/test/java/org/springframework/statemachine/StateMachineResetTests.java index 492e5094..cabd47e8 100644 --- a/spring-statemachine-core/src/test/java/org/springframework/statemachine/StateMachineResetTests.java +++ b/spring-statemachine-core/src/test/java/org/springframework/statemachine/StateMachineResetTests.java @@ -129,7 +129,7 @@ public class StateMachineResetTests extends AbstractStateMachineTests { } @Test - public void testResetRegions() { + public void testResetRegions1() { context.register(Config2.class); context.refresh(); @SuppressWarnings("unchecked") @@ -159,6 +159,37 @@ public class StateMachineResetTests extends AbstractStateMachineTests { assertThat(machine.getState().getIds(), containsInAnyOrder(TestStates.S2, TestStates.S21, TestStates.S31)); } + @Test + public void testResetRegions2() { + context.register(Config2.class); + context.refresh(); + @SuppressWarnings("unchecked") + StateMachine machine = context.getBean(StateMachineSystemConstants.DEFAULT_ID_STATEMACHINE, StateMachine.class); + + DefaultStateMachineContext stateMachineContext1 = + new DefaultStateMachineContext(TestStates.S21, null, null, null); + DefaultStateMachineContext stateMachineContext2 = + new DefaultStateMachineContext(TestStates.S31, null, null, null); + + List> childs = new ArrayList>(); + childs.add(stateMachineContext1); + childs.add(stateMachineContext2); + + DefaultStateMachineContext stateMachineContext = + new DefaultStateMachineContext(childs, TestStates.S2, null, null, null); + + machine.getStateMachineAccessor().doWithAllRegions(new StateMachineFunction>() { + + @Override + public void apply(StateMachineAccess function) { + function.resetStateMachine(stateMachineContext); + } + }); + + machine.start(); + assertThat(machine.getState().getIds(), containsInAnyOrder(TestStates.S2, TestStates.S21, TestStates.S31)); + } + @Configuration @EnableStateMachine static class Config1 extends EnumStateMachineConfigurerAdapter { diff --git a/spring-statemachine-recipes/src/main/java/org/springframework/statemachine/recipes/tasks/TasksHandler.java b/spring-statemachine-recipes/src/main/java/org/springframework/statemachine/recipes/tasks/TasksHandler.java index 8e1e00aa..b8d0c989 100644 --- a/spring-statemachine-recipes/src/main/java/org/springframework/statemachine/recipes/tasks/TasksHandler.java +++ b/spring-statemachine-recipes/src/main/java/org/springframework/statemachine/recipes/tasks/TasksHandler.java @@ -25,11 +25,15 @@ import java.util.Map.Entry; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.core.task.TaskExecutor; +import org.springframework.messaging.Message; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.statemachine.StateContext; import org.springframework.statemachine.StateMachine; +import org.springframework.statemachine.StateMachineContext; import org.springframework.statemachine.StateMachineException; import org.springframework.statemachine.StateMachinePersist; +import org.springframework.statemachine.access.StateMachineAccess; +import org.springframework.statemachine.access.StateMachineFunction; import org.springframework.statemachine.action.Action; import org.springframework.statemachine.config.StateMachineBuilder; import org.springframework.statemachine.config.builders.StateMachineStateConfigurer; @@ -37,9 +41,16 @@ import org.springframework.statemachine.config.builders.StateMachineTransitionCo import org.springframework.statemachine.guard.Guard; import org.springframework.statemachine.listener.AbstractCompositeListener; import org.springframework.statemachine.recipes.support.RunnableAction; +import org.springframework.statemachine.state.PseudoStateKind; +import org.springframework.statemachine.state.State; +import org.springframework.statemachine.support.DefaultStateMachineContext; +import org.springframework.statemachine.support.StateMachineInterceptor; +import org.springframework.statemachine.support.StateMachineInterceptorAdapter; +import org.springframework.statemachine.support.StateMachineUtils; import org.springframework.statemachine.support.tree.Tree; import org.springframework.statemachine.support.tree.Tree.Node; import org.springframework.statemachine.support.tree.TreeTraverser; +import org.springframework.statemachine.transition.Transition; /** * {@code TasksHandler} is a recipe for executing arbitrary {@link Runnable} tasks @@ -74,6 +85,7 @@ public class TasksHandler { private StateMachine stateMachine; private final CompositeTasksListener listener = new CompositeTasksListener(); + private final StateMachinePersist persist; /** * Instantiates a new tasks handler. Intentionally private instantiation @@ -81,10 +93,25 @@ public class TasksHandler { * * @param tasks the wrapped tasks * @param listener the tasks listener + * @param taskExecutor the task executor + * @param persist the state machine persist */ - private TasksHandler(List tasks, TasksListener listener, TaskExecutor taskExecutor) { + private TasksHandler(List tasks, TasksListener listener, TaskExecutor taskExecutor, + StateMachinePersist persist) { + this.persist = persist; try { - this.stateMachine = buildStateMachine(tasks, taskExecutor); + stateMachine = buildStateMachine(tasks, taskExecutor); + if (persist != null) { + final LocalStateMachineInterceptor interceptor = new LocalStateMachineInterceptor(persist); + stateMachine.getStateMachineAccessor() + .doWithAllRegions(new StateMachineFunction>() { + + @Override + public void apply(StateMachineAccess function) { + function.addStateMachineInterceptor(interceptor); + } + }); + } } catch (Exception e) { throw new StateMachineException("Error building state machine from tasks", e); } @@ -114,6 +141,37 @@ public class TasksHandler { stateMachine.sendEvent(EVENT_FIX); } + /** + * Resets state machine states from a backing persistent repository. If + * {@link StateMachinePersist} is not set this method doesn't do anything. + * {@link StateMachine} is stopped before states are reseted from a persistent + * store and started afterwards. + */ + public void resetFromPersistStore() { + if (persist == null) { + // TODO: should we throw or silently return? + return; + } + + final StateMachineContext context; + try { + context = persist.read(null); + } catch (Exception e) { + throw new StateMachineException("Error reading state from persistent store", e); + } + + stateMachine.stop(); + stateMachine.getStateMachineAccessor() + .doWithAllRegions(new StateMachineFunction>() { + + @Override + public void apply(StateMachineAccess function) { + function.resetStateMachine(context); + } + }); + stateMachine.start(); + } + /** * Adds the tasks listener. * @@ -302,6 +360,7 @@ public class TasksHandler { private final List tasks = new ArrayList(); private TasksListener listener; private TaskExecutor taskExecutor; + private StateMachinePersist persist; /** * Define a top-level task. @@ -336,6 +395,7 @@ public class TasksHandler { * @return the builder for chaining */ public Builder persist(StateMachinePersist persist) { + this.persist = persist; return this; } @@ -369,7 +429,7 @@ public class TasksHandler { * @return the tasks handler */ public TasksHandler build() { - return new TasksHandler(tasks, listener, taskExecutor); + return new TasksHandler(tasks, listener, taskExecutor, persist); } } @@ -384,7 +444,7 @@ public class TasksHandler { } /** - * Gets a loca runnable action. + * Gets a local runnable action. * * @param runnable the runnable * @param id the task id @@ -770,6 +830,54 @@ public class TasksHandler { } + /** + * Local {@link StateMachineInterceptor} persisting state machine states. + */ + private class LocalStateMachineInterceptor extends StateMachineInterceptorAdapter { + + // TODO: should try to find a common way to build context and + // not do tweaks here. + private final StateMachinePersist persist; + private DefaultStateMachineContext currentContext; + private State currentContextState; + private final List> childs = new ArrayList>(); + + public LocalStateMachineInterceptor(StateMachinePersist persist) { + this.persist = persist; + } + + @Override + public void preStateChange(State state, Message message, + Transition transition, StateMachine stateMachine) { + + // skip all other pseudostates than initial + if (state == null || (state.getPseudoState() != null && state.getPseudoState().getKind() != PseudoStateKind.INITIAL)) { + return; + } + + // track root state here and update childs + if (currentContext != null && StateMachineUtils.isSubstate(currentContextState, state)) { + DefaultStateMachineContext context = new DefaultStateMachineContext( + transition != null ? transition.getTarget().getId() : null, message != null ? message.getPayload() + : null, message != null ? message.getHeaders() : null, stateMachine.getExtendedState()); + currentContext.getChilds().add(context); + } else { + childs.clear(); + DefaultStateMachineContext context = new DefaultStateMachineContext( + new ArrayList>(childs), state.getId(), message != null ? message.getPayload() + : null, message != null ? message.getHeaders() : null, stateMachine.getExtendedState()); + currentContext = context; + currentContextState = state; + } + + try { + persist.write(currentContext, null); + } catch (Exception e) { + throw new StateMachineException("Error persisting", e); + } + } + } + /** * Wrapping a {@link Runnable} with a task identifier and parent if task * is a subtask. If parent is null it indicates that a task is a top-level diff --git a/spring-statemachine-recipes/src/test/java/org/springframework/statemachine/recipes/TasksHandlerTests.java b/spring-statemachine-recipes/src/test/java/org/springframework/statemachine/recipes/TasksHandlerTests.java index 0b45a140..37682c30 100644 --- a/spring-statemachine-recipes/src/test/java/org/springframework/statemachine/recipes/TasksHandlerTests.java +++ b/spring-statemachine-recipes/src/test/java/org/springframework/statemachine/recipes/TasksHandlerTests.java @@ -28,10 +28,12 @@ import java.util.concurrent.TimeUnit; import org.junit.Test; import org.springframework.statemachine.StateContext; import org.springframework.statemachine.StateMachine; +import org.springframework.statemachine.StateMachineContext; import org.springframework.statemachine.listener.StateMachineListenerAdapter; import org.springframework.statemachine.recipes.tasks.TasksHandler; import org.springframework.statemachine.recipes.tasks.TasksHandler.TasksListenerAdapter; import org.springframework.statemachine.state.State; +import org.springframework.statemachine.support.DefaultStateMachineContext; import org.springframework.statemachine.transition.Transition; public class TasksHandlerTests { @@ -309,6 +311,182 @@ public class TasksHandlerTests { assertThat(tasksListener.onTasksContinue, is(1)); } + @Test + public void testPersist1() throws InterruptedException { + TestStateMachinePersist persist = new TestStateMachinePersist(); + TasksHandler handler = TasksHandler.builder() + .task("1", sleepRunnable()) + .task("2", sleepRunnable()) + .task("3", sleepRunnable()) + .persist(persist) + .build(); + + TestListener listener = new TestListener(); + listener.reset(10, 0, 0); + StateMachine machine = handler.getStateMachine(); + machine.addStateListener(listener); + machine.start(); + assertThat(listener.stateMachineStartedLatch.await(1, TimeUnit.SECONDS), is(true)); + + persist.reset(5); + + handler.runTasks(); + + assertThat(listener.stateChangedLatch.await(8, TimeUnit.SECONDS), is(true)); + assertThat(listener.stateChangedCount, is(10)); + assertThat(machine.getState().getIds(), contains(TasksHandler.STATE_READY)); + Map variables = machine.getExtendedState().getVariables(); + assertThat(variables.size(), is(3)); + + assertThat(persist.writeLatch.await(4, TimeUnit.SECONDS), is(true)); + assertThat(persist.contexts.size(), is(5)); + + for (StateMachineContext context : persist.getContexts()) { + if (context.getState() == "TASKS") { + assertThat(context.getChilds().size(), is(3)); + } else { + assertThat(context.getChilds().size(), is(0)); + } + } + } + + @Test + public void testPersist2() throws InterruptedException { + TestStateMachinePersist persist = new TestStateMachinePersist(); + TasksHandler handler = TasksHandler.builder() + .task("1", sleepRunnable()) + .task("2", sleepRunnable()) + .task("3", failRunnable()) + .persist(persist) + .build(); + + TestListener listener = new TestListener(); + listener.reset(10, 0, 0); + StateMachine machine = handler.getStateMachine(); + machine.addStateListener(listener); + machine.start(); + assertThat(listener.stateMachineStartedLatch.await(1, TimeUnit.SECONDS), is(true)); + + persist.reset(6); + + handler.runTasks(); + + assertThat(listener.stateChangedLatch.await(8, TimeUnit.SECONDS), is(true)); + assertThat(listener.stateChangedCount, is(10)); + assertThat(machine.getState().getIds(), contains(TasksHandler.STATE_ERROR, TasksHandler.STATE_AUTOMATIC)); + Map variables = machine.getExtendedState().getVariables(); + assertThat(variables.size(), is(3)); + + assertThat(persist.writeLatch.await(4, TimeUnit.SECONDS), is(true)); + assertThat(persist.contexts.size(), is(6)); + + for (StateMachineContext context : persist.getContexts()) { + if (context.getState() == "TASKS") { + assertThat(context.getChilds().size(), is(3)); + } else if (context.getState() == "ERROR") { + assertThat(context.getChilds().size(), is(1)); + } else { + assertThat(context.getChilds().size(), is(0)); + } + } + } + + @Test + public void testReset1() throws InterruptedException { + TestStateMachinePersist persist = new TestStateMachinePersist(); + TasksHandler handler = TasksHandler.builder() + .task("1", sleepRunnable()) + .task("2", sleepRunnable()) + .task("3", sleepRunnable()) + .persist(persist) + .build(); + + TestListener listener = new TestListener(); + StateMachine machine = handler.getStateMachine(); + machine.addStateListener(listener); + handler.resetFromPersistStore(); + assertThat(listener.stateMachineStartedLatch.await(1, TimeUnit.SECONDS), is(true)); + } + + @Test + public void testReset2() throws InterruptedException { + DefaultStateMachineContext child = new DefaultStateMachineContext("MANUAL", null, null, null); + List> childs = new ArrayList>(); + childs.add(child); + DefaultStateMachineContext context = new DefaultStateMachineContext(childs, "ERROR", null, null, null); + TestStateMachinePersist persist = new TestStateMachinePersist(context); + TasksHandler handler = TasksHandler.builder() + .task("1", sleepRunnable()) + .task("2", sleepRunnable()) + .task("3", sleepRunnable()) + .persist(persist) + .build(); + + TestListener listener = new TestListener(); + StateMachine machine = handler.getStateMachine(); + machine.addStateListener(listener); + + handler.resetFromPersistStore(); + + assertThat(listener.stateMachineStartedLatch.await(1, TimeUnit.SECONDS), is(true)); + assertThat(machine.getState().getIds(), contains(TasksHandler.STATE_ERROR, TasksHandler.STATE_MANUAL)); + } + + @Test + public void testReset3() throws InterruptedException { + List> childs = new ArrayList>(); + DefaultStateMachineContext context = new DefaultStateMachineContext(childs, "ERROR", null, null, null); + TestStateMachinePersist persist = new TestStateMachinePersist(context); + TasksHandler handler = TasksHandler.builder() + .task("1", sleepRunnable()) + .task("2", sleepRunnable()) + .task("3", sleepRunnable()) + .persist(persist) + .build(); + + TestListener listener = new TestListener(); + listener.reset(2, 0, 0); + StateMachine machine = handler.getStateMachine(); + machine.addStateListener(listener); + + handler.resetFromPersistStore(); + + assertThat(listener.stateMachineStartedLatch.await(1, TimeUnit.SECONDS), is(true)); + + assertThat(listener.stateChangedLatch.await(4, TimeUnit.SECONDS), is(true)); + assertThat(listener.stateChangedCount, is(2)); + assertThat(machine.getState().getIds(), contains(TasksHandler.STATE_READY)); + } + + //@Test + public void testReset4() throws InterruptedException { + // TODO: automaticAction() is not executed when state is reset + DefaultStateMachineContext child = new DefaultStateMachineContext("AUTOMATIC", null, null, null); + List> childs = new ArrayList>(); + childs.add(child); + DefaultStateMachineContext context = new DefaultStateMachineContext(childs, "ERROR", null, null, null); + TestStateMachinePersist persist = new TestStateMachinePersist(context); + TasksHandler handler = TasksHandler.builder() + .task("1", sleepRunnable()) + .task("2", sleepRunnable()) + .task("3", sleepRunnable()) + .persist(persist) + .build(); + + TestListener listener = new TestListener(); + listener.reset(2, 0, 0); + StateMachine machine = handler.getStateMachine(); + machine.addStateListener(listener); + + handler.resetFromPersistStore(); + + assertThat(listener.stateMachineStartedLatch.await(1, TimeUnit.SECONDS), is(true)); + + assertThat(listener.stateChangedLatch.await(4, TimeUnit.SECONDS), is(true)); + assertThat(listener.stateChangedCount, is(2)); + assertThat(machine.getState().getIds(), contains(TasksHandler.STATE_READY)); + } + private static Runnable sleepRunnable() { return new Runnable() { diff --git a/spring-statemachine-recipes/src/test/java/org/springframework/statemachine/recipes/TestStateMachinePersist.java b/spring-statemachine-recipes/src/test/java/org/springframework/statemachine/recipes/TestStateMachinePersist.java new file mode 100644 index 00000000..2ed44b04 --- /dev/null +++ b/spring-statemachine-recipes/src/test/java/org/springframework/statemachine/recipes/TestStateMachinePersist.java @@ -0,0 +1,62 @@ +/* + * Copyright 2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.statemachine.recipes; + +import java.util.ArrayList; +import java.util.concurrent.CountDownLatch; + +import org.springframework.statemachine.StateMachineContext; +import org.springframework.statemachine.StateMachinePersist; + +public class TestStateMachinePersist implements StateMachinePersist { + + public final ArrayList> contexts = new ArrayList<>(); + public volatile CountDownLatch writeLatch = new CountDownLatch(1); + private StateMachineContext context; + + public TestStateMachinePersist() { + } + + public TestStateMachinePersist(StateMachineContext context) { + this.context = context; + } + + @Override + public void write(StateMachineContext context, Void contextOjb) throws Exception { + synchronized (this) { + contexts.add(context); + } + this.context = context; + writeLatch.countDown(); + } + + @Override + public StateMachineContext read(Void contextOjb) throws Exception { + return context; + } + + public ArrayList> getContexts() { + return contexts; + } + + public void reset(int c1) { + synchronized (this) { + contexts.clear(); + writeLatch = new CountDownLatch(c1); + } + } + +}