diff --git a/spring-tx/src/main/java/org/springframework/transaction/reactive/TransactionalOperatorImpl.java b/spring-tx/src/main/java/org/springframework/transaction/reactive/TransactionalOperatorImpl.java index ea2ab1d7db..ec8e22f11b 100644 --- a/spring-tx/src/main/java/org/springframework/transaction/reactive/TransactionalOperatorImpl.java +++ b/spring-tx/src/main/java/org/springframework/transaction/reactive/TransactionalOperatorImpl.java @@ -78,14 +78,11 @@ final class TransactionalOperatorImpl implements TransactionalOperator { return status.flatMapMany(it -> { // This is an around advice: Invoke the next interceptor in the chain. // This will normally result in a target object being invoked. - Flux retVal = Flux.from(action.doInTransaction(it)); - return retVal.onErrorResume(ex -> rollbackOnException(it, ex). - then(Mono.error(ex))).materialize().flatMap(signal -> { - if (signal.isOnComplete()) { - return this.transactionManager.commit(it).materialize(); - } - return Mono.just(signal); - }).dematerialize(); + // Need re-wrapping of ReactiveTransaction until we get hold of the exception + // through usingWhen. + return Flux.usingWhen(Mono.just(it), action::doInTransaction, + this.transactionManager::commit, s -> Mono.empty()) + .onErrorResume(ex -> rollbackOnException(it, ex).then(Mono.error(ex))); }); }) .subscriberContext(TransactionContextManager.getOrCreateContext()) diff --git a/spring-tx/src/test/java/org/springframework/transaction/reactive/TransactionalOperatorTests.java b/spring-tx/src/test/java/org/springframework/transaction/reactive/TransactionalOperatorTests.java index c6308e7e86..9ccf682f01 100644 --- a/spring-tx/src/test/java/org/springframework/transaction/reactive/TransactionalOperatorTests.java +++ b/spring-tx/src/test/java/org/springframework/transaction/reactive/TransactionalOperatorTests.java @@ -59,9 +59,9 @@ public class TransactionalOperatorTests { @Test public void commitWithFlux() { TransactionalOperator operator = TransactionalOperator.create(tm, new DefaultTransactionDefinition()); - Flux.just(true).as(operator::transactional) + Flux.just(1, 2, 3, 4).as(operator::transactional) .as(StepVerifier::create) - .expectNext(true) + .expectNextCount(4) .verifyComplete(); assertTrue(tm.commit); assertFalse(tm.rollback);