Add better support for region persist

- Various changes to have a better support for persisting
  regions.
- Fixes for TasksHandler.
- Fixes #94
This commit is contained in:
Janne Valkealahti
2015-10-09 17:35:04 +01:00
parent 7072425400
commit 280e77347e
7 changed files with 416 additions and 7 deletions

View File

@@ -287,6 +287,10 @@ public abstract class AbstractStateMachine<S, E> extends StateMachineObjectSuppo
}
stateMachineExecutor.setInitialEnabled(false);
stateMachineExecutor.start();
// assume that state was set/reseted so we need to
// dispatch started event which would net getting
// dispatched via executor
notifyStateMachineStarted(getRelayStateMachine());
return;
}
registerPseudoStateListener();
@@ -467,10 +471,16 @@ public abstract class AbstractStateMachine<S, E> extends StateMachineObjectSuppo
@Override
public void resetStateMachine(StateMachineContext<S, E> stateMachineContext) {
if (stateMachineContext == null) {
return;
}
if (log.isDebugEnabled()) {
log.debug("Request to reset state machine: stateMachine=[" + this + "] stateMachineContext=[" + stateMachineContext + "]");
}
S state = stateMachineContext.getState();
if (state == null) {
return;
}
boolean stateSet = false;
for (State<S, E> s : getStates()) {
for (State<S, E> ss : s.getStates()) {
@@ -480,6 +490,15 @@ public abstract class AbstractStateMachine<S, E> extends StateMachineObjectSuppo
// needed if we only transit to super state or reset regions
if (s.isSubmachineState()) {
StateMachine<S, E> submachine = ((AbstractState<S, E>)s).getSubmachine();
for (final StateMachineContext<S, E> child : stateMachineContext.getChilds()) {
submachine.getStateMachineAccessor().doWithRegion(new StateMachineFunction<StateMachineAccess<S,E>>() {
@Override
public void apply(StateMachineAccess<S, E> function) {
function.resetStateMachine(child);
}
});
}
submachine.start();
} else if (s.isOrthogonal() && stateMachineContext.getChilds() != null) {
Collection<Region<S, E>> regions = ((AbstractState<S, E>)s).getRegions();
@@ -591,10 +610,16 @@ public abstract class AbstractStateMachine<S, E> extends StateMachineObjectSuppo
|| kind == PseudoStateKind.HISTORY_DEEP) {
StateContext<S, E> stateContext = buildStateContext(message, transition, stateMachine);
State<S, E> toState = state.getPseudoState().entry(stateContext);
if (kind == PseudoStateKind.CHOICE) {
callPreStateChangeInterceptors(toState, message, transition, stateMachine);
}
setCurrentState(toState, message, transition, true, stateMachine);
} else if (kind == PseudoStateKind.FORK) {
ForkPseudoState<S, E> fps = (ForkPseudoState<S, E>) state.getPseudoState();
for (State<S, E> ss : fps.getForks()) {
callPreStateChangeInterceptors(ss, message, transition, stateMachine);
setCurrentState(ss, message, transition, false, stateMachine);
}
} else {

View File

@@ -95,8 +95,8 @@ public class DefaultStateMachineContext<S, E> implements StateMachineContext<S,
@Override
public String toString() {
return "DefaultStateMachineContext [state=" + state + ", event=" + event + ", eventHeaders=" + eventHeaders
+ ", extendedState=" + extendedState + "]";
return "DefaultStateMachineContext [childs=" + childs + ", state=" + state + ", event=" + event
+ ", eventHeaders=" + eventHeaders + ", extendedState=" + extendedState + "]";
}
}

View File

@@ -147,4 +147,9 @@ public class StateMachineInterceptorList<S, E> {
return exception;
}
@Override
public String toString() {
return "StateMachineInterceptorList [interceptors=" + interceptors + "]";
}
}

View File

@@ -129,7 +129,7 @@ public class StateMachineResetTests extends AbstractStateMachineTests {
}
@Test
public void testResetRegions() {
public void testResetRegions1() {
context.register(Config2.class);
context.refresh();
@SuppressWarnings("unchecked")
@@ -159,6 +159,37 @@ public class StateMachineResetTests extends AbstractStateMachineTests {
assertThat(machine.getState().getIds(), containsInAnyOrder(TestStates.S2, TestStates.S21, TestStates.S31));
}
@Test
public void testResetRegions2() {
context.register(Config2.class);
context.refresh();
@SuppressWarnings("unchecked")
StateMachine<TestStates, TestEvents> machine = context.getBean(StateMachineSystemConstants.DEFAULT_ID_STATEMACHINE, StateMachine.class);
DefaultStateMachineContext<TestStates, TestEvents> stateMachineContext1 =
new DefaultStateMachineContext<TestStates, TestEvents>(TestStates.S21, null, null, null);
DefaultStateMachineContext<TestStates, TestEvents> stateMachineContext2 =
new DefaultStateMachineContext<TestStates, TestEvents>(TestStates.S31, null, null, null);
List<StateMachineContext<TestStates, TestEvents>> childs = new ArrayList<StateMachineContext<TestStates,TestEvents>>();
childs.add(stateMachineContext1);
childs.add(stateMachineContext2);
DefaultStateMachineContext<TestStates, TestEvents> stateMachineContext =
new DefaultStateMachineContext<TestStates, TestEvents>(childs, TestStates.S2, null, null, null);
machine.getStateMachineAccessor().doWithAllRegions(new StateMachineFunction<StateMachineAccess<TestStates, TestEvents>>() {
@Override
public void apply(StateMachineAccess<TestStates, TestEvents> function) {
function.resetStateMachine(stateMachineContext);
}
});
machine.start();
assertThat(machine.getState().getIds(), containsInAnyOrder(TestStates.S2, TestStates.S21, TestStates.S31));
}
@Configuration
@EnableStateMachine
static class Config1 extends EnumStateMachineConfigurerAdapter<States, Events> {

View File

@@ -25,11 +25,15 @@ import java.util.Map.Entry;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.task.TaskExecutor;
import org.springframework.messaging.Message;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.statemachine.StateContext;
import org.springframework.statemachine.StateMachine;
import org.springframework.statemachine.StateMachineContext;
import org.springframework.statemachine.StateMachineException;
import org.springframework.statemachine.StateMachinePersist;
import org.springframework.statemachine.access.StateMachineAccess;
import org.springframework.statemachine.access.StateMachineFunction;
import org.springframework.statemachine.action.Action;
import org.springframework.statemachine.config.StateMachineBuilder;
import org.springframework.statemachine.config.builders.StateMachineStateConfigurer;
@@ -37,9 +41,16 @@ import org.springframework.statemachine.config.builders.StateMachineTransitionCo
import org.springframework.statemachine.guard.Guard;
import org.springframework.statemachine.listener.AbstractCompositeListener;
import org.springframework.statemachine.recipes.support.RunnableAction;
import org.springframework.statemachine.state.PseudoStateKind;
import org.springframework.statemachine.state.State;
import org.springframework.statemachine.support.DefaultStateMachineContext;
import org.springframework.statemachine.support.StateMachineInterceptor;
import org.springframework.statemachine.support.StateMachineInterceptorAdapter;
import org.springframework.statemachine.support.StateMachineUtils;
import org.springframework.statemachine.support.tree.Tree;
import org.springframework.statemachine.support.tree.Tree.Node;
import org.springframework.statemachine.support.tree.TreeTraverser;
import org.springframework.statemachine.transition.Transition;
/**
* {@code TasksHandler} is a recipe for executing arbitrary {@link Runnable} tasks
@@ -74,6 +85,7 @@ public class TasksHandler {
private StateMachine<String, String> stateMachine;
private final CompositeTasksListener listener = new CompositeTasksListener();
private final StateMachinePersist<String, String, Void> persist;
/**
* Instantiates a new tasks handler. Intentionally private instantiation
@@ -81,10 +93,25 @@ public class TasksHandler {
*
* @param tasks the wrapped tasks
* @param listener the tasks listener
* @param taskExecutor the task executor
* @param persist the state machine persist
*/
private TasksHandler(List<TaskWrapper> tasks, TasksListener listener, TaskExecutor taskExecutor) {
private TasksHandler(List<TaskWrapper> tasks, TasksListener listener, TaskExecutor taskExecutor,
StateMachinePersist<String, String, Void> persist) {
this.persist = persist;
try {
this.stateMachine = buildStateMachine(tasks, taskExecutor);
stateMachine = buildStateMachine(tasks, taskExecutor);
if (persist != null) {
final LocalStateMachineInterceptor interceptor = new LocalStateMachineInterceptor(persist);
stateMachine.getStateMachineAccessor()
.doWithAllRegions(new StateMachineFunction<StateMachineAccess<String, String>>() {
@Override
public void apply(StateMachineAccess<String, String> function) {
function.addStateMachineInterceptor(interceptor);
}
});
}
} catch (Exception e) {
throw new StateMachineException("Error building state machine from tasks", e);
}
@@ -114,6 +141,37 @@ public class TasksHandler {
stateMachine.sendEvent(EVENT_FIX);
}
/**
* Resets state machine states from a backing persistent repository. If
* {@link StateMachinePersist} is not set this method doesn't do anything.
* {@link StateMachine} is stopped before states are reseted from a persistent
* store and started afterwards.
*/
public void resetFromPersistStore() {
if (persist == null) {
// TODO: should we throw or silently return?
return;
}
final StateMachineContext<String, String> context;
try {
context = persist.read(null);
} catch (Exception e) {
throw new StateMachineException("Error reading state from persistent store", e);
}
stateMachine.stop();
stateMachine.getStateMachineAccessor()
.doWithAllRegions(new StateMachineFunction<StateMachineAccess<String, String>>() {
@Override
public void apply(StateMachineAccess<String, String> function) {
function.resetStateMachine(context);
}
});
stateMachine.start();
}
/**
* Adds the tasks listener.
*
@@ -302,6 +360,7 @@ public class TasksHandler {
private final List<TaskWrapper> tasks = new ArrayList<TaskWrapper>();
private TasksListener listener;
private TaskExecutor taskExecutor;
private StateMachinePersist<String, String, Void> persist;
/**
* Define a top-level task.
@@ -336,6 +395,7 @@ public class TasksHandler {
* @return the builder for chaining
*/
public Builder persist(StateMachinePersist<String, String, Void> persist) {
this.persist = persist;
return this;
}
@@ -369,7 +429,7 @@ public class TasksHandler {
* @return the tasks handler
*/
public TasksHandler build() {
return new TasksHandler(tasks, listener, taskExecutor);
return new TasksHandler(tasks, listener, taskExecutor, persist);
}
}
@@ -384,7 +444,7 @@ public class TasksHandler {
}
/**
* Gets a loca runnable action.
* Gets a local runnable action.
*
* @param runnable the runnable
* @param id the task id
@@ -770,6 +830,54 @@ public class TasksHandler {
}
/**
* Local {@link StateMachineInterceptor} persisting state machine states.
*/
private class LocalStateMachineInterceptor extends StateMachineInterceptorAdapter<String, String> {
// TODO: should try to find a common way to build context and
// not do tweaks here.
private final StateMachinePersist<String, String, Void> persist;
private DefaultStateMachineContext<String, String> currentContext;
private State<String, String> currentContextState;
private final List<StateMachineContext<String, String>> childs = new ArrayList<StateMachineContext<String, String>>();
public LocalStateMachineInterceptor(StateMachinePersist<String, String, Void> persist) {
this.persist = persist;
}
@Override
public void preStateChange(State<String, String> state, Message<String> message,
Transition<String, String> transition, StateMachine<String, String> stateMachine) {
// skip all other pseudostates than initial
if (state == null || (state.getPseudoState() != null && state.getPseudoState().getKind() != PseudoStateKind.INITIAL)) {
return;
}
// track root state here and update childs
if (currentContext != null && StateMachineUtils.isSubstate(currentContextState, state)) {
DefaultStateMachineContext<String, String> context = new DefaultStateMachineContext<String, String>(
transition != null ? transition.getTarget().getId() : null, message != null ? message.getPayload()
: null, message != null ? message.getHeaders() : null, stateMachine.getExtendedState());
currentContext.getChilds().add(context);
} else {
childs.clear();
DefaultStateMachineContext<String, String> context = new DefaultStateMachineContext<String, String>(
new ArrayList<StateMachineContext<String, String>>(childs), state.getId(), message != null ? message.getPayload()
: null, message != null ? message.getHeaders() : null, stateMachine.getExtendedState());
currentContext = context;
currentContextState = state;
}
try {
persist.write(currentContext, null);
} catch (Exception e) {
throw new StateMachineException("Error persisting", e);
}
}
}
/**
* Wrapping a {@link Runnable} with a task identifier and parent if task
* is a subtask. If parent is null it indicates that a task is a top-level

View File

@@ -28,10 +28,12 @@ import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.springframework.statemachine.StateContext;
import org.springframework.statemachine.StateMachine;
import org.springframework.statemachine.StateMachineContext;
import org.springframework.statemachine.listener.StateMachineListenerAdapter;
import org.springframework.statemachine.recipes.tasks.TasksHandler;
import org.springframework.statemachine.recipes.tasks.TasksHandler.TasksListenerAdapter;
import org.springframework.statemachine.state.State;
import org.springframework.statemachine.support.DefaultStateMachineContext;
import org.springframework.statemachine.transition.Transition;
public class TasksHandlerTests {
@@ -309,6 +311,182 @@ public class TasksHandlerTests {
assertThat(tasksListener.onTasksContinue, is(1));
}
@Test
public void testPersist1() throws InterruptedException {
TestStateMachinePersist persist = new TestStateMachinePersist();
TasksHandler handler = TasksHandler.builder()
.task("1", sleepRunnable())
.task("2", sleepRunnable())
.task("3", sleepRunnable())
.persist(persist)
.build();
TestListener listener = new TestListener();
listener.reset(10, 0, 0);
StateMachine<String, String> machine = handler.getStateMachine();
machine.addStateListener(listener);
machine.start();
assertThat(listener.stateMachineStartedLatch.await(1, TimeUnit.SECONDS), is(true));
persist.reset(5);
handler.runTasks();
assertThat(listener.stateChangedLatch.await(8, TimeUnit.SECONDS), is(true));
assertThat(listener.stateChangedCount, is(10));
assertThat(machine.getState().getIds(), contains(TasksHandler.STATE_READY));
Map<Object, Object> variables = machine.getExtendedState().getVariables();
assertThat(variables.size(), is(3));
assertThat(persist.writeLatch.await(4, TimeUnit.SECONDS), is(true));
assertThat(persist.contexts.size(), is(5));
for (StateMachineContext<String, String> context : persist.getContexts()) {
if (context.getState() == "TASKS") {
assertThat(context.getChilds().size(), is(3));
} else {
assertThat(context.getChilds().size(), is(0));
}
}
}
@Test
public void testPersist2() throws InterruptedException {
TestStateMachinePersist persist = new TestStateMachinePersist();
TasksHandler handler = TasksHandler.builder()
.task("1", sleepRunnable())
.task("2", sleepRunnable())
.task("3", failRunnable())
.persist(persist)
.build();
TestListener listener = new TestListener();
listener.reset(10, 0, 0);
StateMachine<String, String> machine = handler.getStateMachine();
machine.addStateListener(listener);
machine.start();
assertThat(listener.stateMachineStartedLatch.await(1, TimeUnit.SECONDS), is(true));
persist.reset(6);
handler.runTasks();
assertThat(listener.stateChangedLatch.await(8, TimeUnit.SECONDS), is(true));
assertThat(listener.stateChangedCount, is(10));
assertThat(machine.getState().getIds(), contains(TasksHandler.STATE_ERROR, TasksHandler.STATE_AUTOMATIC));
Map<Object, Object> variables = machine.getExtendedState().getVariables();
assertThat(variables.size(), is(3));
assertThat(persist.writeLatch.await(4, TimeUnit.SECONDS), is(true));
assertThat(persist.contexts.size(), is(6));
for (StateMachineContext<String, String> context : persist.getContexts()) {
if (context.getState() == "TASKS") {
assertThat(context.getChilds().size(), is(3));
} else if (context.getState() == "ERROR") {
assertThat(context.getChilds().size(), is(1));
} else {
assertThat(context.getChilds().size(), is(0));
}
}
}
@Test
public void testReset1() throws InterruptedException {
TestStateMachinePersist persist = new TestStateMachinePersist();
TasksHandler handler = TasksHandler.builder()
.task("1", sleepRunnable())
.task("2", sleepRunnable())
.task("3", sleepRunnable())
.persist(persist)
.build();
TestListener listener = new TestListener();
StateMachine<String, String> machine = handler.getStateMachine();
machine.addStateListener(listener);
handler.resetFromPersistStore();
assertThat(listener.stateMachineStartedLatch.await(1, TimeUnit.SECONDS), is(true));
}
@Test
public void testReset2() throws InterruptedException {
DefaultStateMachineContext<String, String> child = new DefaultStateMachineContext<String, String>("MANUAL", null, null, null);
List<StateMachineContext<String, String>> childs = new ArrayList<StateMachineContext<String, String>>();
childs.add(child);
DefaultStateMachineContext<String, String> context = new DefaultStateMachineContext<String, String>(childs, "ERROR", null, null, null);
TestStateMachinePersist persist = new TestStateMachinePersist(context);
TasksHandler handler = TasksHandler.builder()
.task("1", sleepRunnable())
.task("2", sleepRunnable())
.task("3", sleepRunnable())
.persist(persist)
.build();
TestListener listener = new TestListener();
StateMachine<String, String> machine = handler.getStateMachine();
machine.addStateListener(listener);
handler.resetFromPersistStore();
assertThat(listener.stateMachineStartedLatch.await(1, TimeUnit.SECONDS), is(true));
assertThat(machine.getState().getIds(), contains(TasksHandler.STATE_ERROR, TasksHandler.STATE_MANUAL));
}
@Test
public void testReset3() throws InterruptedException {
List<StateMachineContext<String, String>> childs = new ArrayList<StateMachineContext<String, String>>();
DefaultStateMachineContext<String, String> context = new DefaultStateMachineContext<String, String>(childs, "ERROR", null, null, null);
TestStateMachinePersist persist = new TestStateMachinePersist(context);
TasksHandler handler = TasksHandler.builder()
.task("1", sleepRunnable())
.task("2", sleepRunnable())
.task("3", sleepRunnable())
.persist(persist)
.build();
TestListener listener = new TestListener();
listener.reset(2, 0, 0);
StateMachine<String, String> machine = handler.getStateMachine();
machine.addStateListener(listener);
handler.resetFromPersistStore();
assertThat(listener.stateMachineStartedLatch.await(1, TimeUnit.SECONDS), is(true));
assertThat(listener.stateChangedLatch.await(4, TimeUnit.SECONDS), is(true));
assertThat(listener.stateChangedCount, is(2));
assertThat(machine.getState().getIds(), contains(TasksHandler.STATE_READY));
}
//@Test
public void testReset4() throws InterruptedException {
// TODO: automaticAction() is not executed when state is reset
DefaultStateMachineContext<String, String> child = new DefaultStateMachineContext<String, String>("AUTOMATIC", null, null, null);
List<StateMachineContext<String, String>> childs = new ArrayList<StateMachineContext<String, String>>();
childs.add(child);
DefaultStateMachineContext<String, String> context = new DefaultStateMachineContext<String, String>(childs, "ERROR", null, null, null);
TestStateMachinePersist persist = new TestStateMachinePersist(context);
TasksHandler handler = TasksHandler.builder()
.task("1", sleepRunnable())
.task("2", sleepRunnable())
.task("3", sleepRunnable())
.persist(persist)
.build();
TestListener listener = new TestListener();
listener.reset(2, 0, 0);
StateMachine<String, String> machine = handler.getStateMachine();
machine.addStateListener(listener);
handler.resetFromPersistStore();
assertThat(listener.stateMachineStartedLatch.await(1, TimeUnit.SECONDS), is(true));
assertThat(listener.stateChangedLatch.await(4, TimeUnit.SECONDS), is(true));
assertThat(listener.stateChangedCount, is(2));
assertThat(machine.getState().getIds(), contains(TasksHandler.STATE_READY));
}
private static Runnable sleepRunnable() {
return new Runnable() {

View File

@@ -0,0 +1,62 @@
/*
* Copyright 2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.statemachine.recipes;
import java.util.ArrayList;
import java.util.concurrent.CountDownLatch;
import org.springframework.statemachine.StateMachineContext;
import org.springframework.statemachine.StateMachinePersist;
public class TestStateMachinePersist implements StateMachinePersist<String, String, Void> {
public final ArrayList<StateMachineContext<String, String>> contexts = new ArrayList<>();
public volatile CountDownLatch writeLatch = new CountDownLatch(1);
private StateMachineContext<String, String> context;
public TestStateMachinePersist() {
}
public TestStateMachinePersist(StateMachineContext<String, String> context) {
this.context = context;
}
@Override
public void write(StateMachineContext<String, String> context, Void contextOjb) throws Exception {
synchronized (this) {
contexts.add(context);
}
this.context = context;
writeLatch.countDown();
}
@Override
public StateMachineContext<String, String> read(Void contextOjb) throws Exception {
return context;
}
public ArrayList<StateMachineContext<String, String>> getContexts() {
return contexts;
}
public void reset(int c1) {
synchronized (this) {
contexts.clear();
writeLatch = new CountDownLatch(c1);
}
}
}