Fix some race conditions in tests

* `SftpSessionFactoryTests.concurrentGetSessionDoesntCauseFailure()`
may not report properly into an `ArrayList` from another thread.
Use `asyncTaskExecutor.submitCompletable()` instead and deal with their result afterward.
* Use `CompletableFuture` for the `TcpListener` logic in the `FailoverClientConnectionFactoryTests.failoverAllDeadAfterSuccess()`
This commit is contained in:
Artem Bilan
2025-03-20 10:19:36 -04:00
parent d73ded52c9
commit 08400df280
3 changed files with 74 additions and 71 deletions

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -35,11 +35,12 @@ import org.springframework.util.Assert;
/**
* Given a list of connection factories, serves up {@link TcpConnection}s
* that can iterate over a connection from each factory until the write
* that can iterate over a connection from each factory until the {@code write}
* succeeds or the list is exhausted.
*
* @author Gary Russell
* @author Christian Tzolov
* @author Artem Bilan
*
* @since 2.2
*
@@ -163,8 +164,9 @@ public class FailoverClientConnectionFactory extends AbstractClientConnectionFac
return sharedConnection;
}
FailoverTcpConnection failoverTcpConnection = new FailoverTcpConnection(this.factories);
if (getListener() != null) {
failoverTcpConnection.registerListener(getListener());
TcpListener listener = getListener();
if (listener != null) {
failoverTcpConnection.registerListener(listener);
}
failoverTcpConnection.incrementEpoch();
if (shared) {
@@ -286,9 +288,7 @@ public class FailoverClientConnectionFactory extends AbstractClientConnectionFac
}
catch (RuntimeException e) {
if (logger.isDebugEnabled()) {
logger.debug(nextFactory + " failed with "
+ e.toString()
+ ", trying another");
logger.debug(nextFactory + " failed with " + e + ", trying another");
}
if (restartedList && (lastFactoryToTry == null || lastFactoryToTry.equals(nextFactory))) {
logger.debug("Failover failed to find a connection");

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2002-2024 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -25,6 +25,7 @@ import java.net.Socket;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
@@ -39,8 +40,6 @@ import org.mockito.InOrder;
import org.mockito.Mockito;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.integration.channel.DirectChannel;
import org.springframework.integration.channel.QueueChannel;
@@ -79,26 +78,13 @@ import static org.mockito.Mockito.when;
*/
public class FailoverClientConnectionFactoryTests {
private static final ApplicationEventPublisher NULL_PUBLISHER = new ApplicationEventPublisher() {
@Override
public void publishEvent(ApplicationEvent event) {
}
@Override
public void publishEvent(Object event) {
}
};
@Test
public void testFailoverGood() throws Exception {
TcpConnectionSupport conn1 = makeMockConnection();
TcpConnectionSupport conn2 = makeMockConnection();
AbstractClientConnectionFactory factory1 = createFactoryWithMockConnection(conn1);
AbstractClientConnectionFactory factory2 = createFactoryWithMockConnection(conn2);
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
factories.add(factory1);
factories.add(factory2);
doThrow(new UncheckedIOException(new IOException("fail")))
@@ -106,7 +92,7 @@ public class FailoverClientConnectionFactoryTests {
doAnswer(invocation -> null).when(conn2).send(Mockito.any(Message.class));
FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories);
failoverFactory.start();
GenericMessage<String> message = new GenericMessage<String>("foo");
GenericMessage<String> message = new GenericMessage<>("foo");
failoverFactory.getConnection().send(message);
Mockito.verify(conn2).send(message);
}
@@ -129,7 +115,7 @@ public class FailoverClientConnectionFactoryTests {
private void testRefreshShared(boolean closeOnRefresh, long interval) throws Exception {
AbstractClientConnectionFactory factory1 = mock(AbstractClientConnectionFactory.class);
AbstractClientConnectionFactory factory2 = mock(AbstractClientConnectionFactory.class);
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
factories.add(factory1);
factories.add(factory2);
TcpConnectionSupport conn1 = makeMockConnection();
@@ -182,7 +168,7 @@ public class FailoverClientConnectionFactoryTests {
TcpConnectionSupport conn2 = makeMockConnection();
AbstractClientConnectionFactory factory1 = createFactoryWithMockConnection(conn1);
AbstractClientConnectionFactory factory2 = createFactoryWithMockConnection(conn2);
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
factories.add(factory1);
factories.add(factory2);
doThrow(new UncheckedIOException(new IOException("fail")))
@@ -191,7 +177,7 @@ public class FailoverClientConnectionFactoryTests {
.when(conn2).send(Mockito.any(Message.class));
FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories);
failoverFactory.start();
GenericMessage<String> message = new GenericMessage<String>("foo");
GenericMessage<String> message = new GenericMessage<>("foo");
assertThatExceptionOfType(UncheckedIOException.class).isThrownBy(() ->
failoverFactory.getConnection().send(message));
Mockito.verify(conn2).send(message);
@@ -214,7 +200,7 @@ public class FailoverClientConnectionFactoryTests {
TcpNetClientConnectionFactory cf1 = new TcpNetClientConnectionFactory("localhost", ss1.getLocalPort());
AbstractClientConnectionFactory cf2 = mock(AbstractClientConnectionFactory.class);
doThrow(new UncheckedIOException(new IOException("fail"))).when(cf2).getConnection();
CountDownLatch latch = new CountDownLatch(2);
CountDownLatch latch = new CountDownLatch(1);
cf1.setApplicationEventPublisher(event -> {
if (event instanceof TcpConnectionCloseEvent) {
latch.countDown();
@@ -223,12 +209,16 @@ public class FailoverClientConnectionFactoryTests {
cf2.setApplicationEventPublisher(event -> {
});
FailoverClientConnectionFactory fccf = new FailoverClientConnectionFactory(List.of(cf1, cf2));
fccf.registerListener(msf -> {
latch.countDown();
return false;
});
CompletableFuture<Message<?>> messageCompletableFuture = new CompletableFuture<>();
fccf.registerListener(messageCompletableFuture::complete);
fccf.start();
fccf.getConnection().send(new GenericMessage<>("test"));
assertThat(messageCompletableFuture)
.succeedsWithin(10, TimeUnit.SECONDS)
.extracting(Message::getPayload)
.isEqualTo("ok".getBytes());
assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue();
assertThatExceptionOfType(UncheckedIOException.class).isThrownBy(() ->
fccf.getConnection().send(new GenericMessage<>("test")));
@@ -240,7 +230,7 @@ public class FailoverClientConnectionFactoryTests {
TcpConnectionSupport conn2 = makeMockConnection();
AbstractClientConnectionFactory factory1 = createFactoryWithMockConnection(conn1);
AbstractClientConnectionFactory factory2 = createFactoryWithMockConnection(conn2);
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
factories.add(factory1);
factories.add(factory2);
final AtomicBoolean failedOnce = new AtomicBoolean();
@@ -255,7 +245,7 @@ public class FailoverClientConnectionFactoryTests {
.when(conn2).send(Mockito.any(Message.class));
FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories);
failoverFactory.start();
GenericMessage<String> message = new GenericMessage<String>("foo");
GenericMessage<String> message = new GenericMessage<>("foo");
failoverFactory.getConnection().send(message);
Mockito.verify(conn2).send(message);
Mockito.verify(conn1, times(2)).send(message);
@@ -265,7 +255,7 @@ public class FailoverClientConnectionFactoryTests {
public void testFailoverConnectNone() throws Exception {
AbstractClientConnectionFactory factory1 = mock(AbstractClientConnectionFactory.class);
AbstractClientConnectionFactory factory2 = mock(AbstractClientConnectionFactory.class);
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
factories.add(factory1);
factories.add(factory2);
when(factory1.getConnection()).thenThrow(new UncheckedIOException(new IOException("fail")));
@@ -274,7 +264,7 @@ public class FailoverClientConnectionFactoryTests {
when(factory2.isActive()).thenReturn(true);
FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories);
failoverFactory.start();
GenericMessage<String> message = new GenericMessage<String>("foo");
GenericMessage<String> message = new GenericMessage<>("foo");
assertThatExceptionOfType(UncheckedIOException.class).isThrownBy(() ->
failoverFactory.getConnection().send(message));
}
@@ -283,7 +273,7 @@ public class FailoverClientConnectionFactoryTests {
public void testFailoverConnectToFirstAfterTriedAll() throws Exception {
AbstractClientConnectionFactory factory1 = mock(AbstractClientConnectionFactory.class);
AbstractClientConnectionFactory factory2 = mock(AbstractClientConnectionFactory.class);
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
factories.add(factory1);
factories.add(factory2);
TcpConnectionSupport conn1 = makeMockConnection();
@@ -308,7 +298,7 @@ public class FailoverClientConnectionFactoryTests {
TcpConnectionSupport conn2 = makeMockConnection();
AbstractClientConnectionFactory factory1 = createFactoryWithMockConnection(conn1);
AbstractClientConnectionFactory factory2 = createFactoryWithMockConnection(conn2);
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
factories.add(factory1);
factories.add(factory2);
final AtomicInteger failCount = new AtomicInteger();
@@ -322,7 +312,7 @@ public class FailoverClientConnectionFactoryTests {
.when(conn2).send(Mockito.any(Message.class));
FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories);
failoverFactory.start();
GenericMessage<String> message = new GenericMessage<String>("foo");
GenericMessage<String> message = new GenericMessage<>("foo");
assertThatExceptionOfType(UncheckedIOException.class)
.isThrownBy(() -> failoverFactory.getConnection().send(message));
failoverFactory.getConnection().send(message);
@@ -426,27 +416,27 @@ public class FailoverClientConnectionFactoryTests {
cachingFactory2.setBeanName("cache2");
// Failover
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
factories.add(cachingFactory1);
factories.add(cachingFactory2);
FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories);
failoverFactory.start();
TcpConnection conn1 = failoverFactory.getConnection();
conn1.send(new GenericMessage<String>("foo1"));
conn1.send(new GenericMessage<>("foo1"));
conn1.close();
TcpConnection conn2 = failoverFactory.getConnection();
assertThat((TestUtils.getPropertyValue(conn2, "delegate", TcpConnectionInterceptorSupport.class))
.getTheConnection())
.isSameAs((TestUtils.getPropertyValue(conn1, "delegate", TcpConnectionInterceptorSupport.class))
.getTheConnection());
conn2.send(new GenericMessage<String>("foo2"));
conn2.send(new GenericMessage<>("foo2"));
conn1 = failoverFactory.getConnection();
assertThat((TestUtils.getPropertyValue(conn2, "delegate", TcpConnectionInterceptorSupport.class))
.getTheConnection())
.isNotSameAs((TestUtils.getPropertyValue(conn1, "delegate", TcpConnectionInterceptorSupport.class))
.getTheConnection());
conn1.send(new GenericMessage<String>("foo3"));
conn1.send(new GenericMessage<>("foo3"));
conn1.close();
conn2.close();
assertThat(latch1.await(10, TimeUnit.SECONDS)).isTrue();
@@ -455,8 +445,8 @@ public class FailoverClientConnectionFactoryTests {
TestingUtilities.waitUntilFactoryHasThisNumberOfConnections(factory1, 0);
conn1 = failoverFactory.getConnection();
conn2 = failoverFactory.getConnection();
conn1.send(new GenericMessage<String>("foo4"));
conn2.send(new GenericMessage<String>("foo5"));
conn1.send(new GenericMessage<>("foo4"));
conn2.send(new GenericMessage<>("foo5"));
conn1.close();
conn2.close();
assertThat(latch2.await(10, TimeUnit.SECONDS)).isTrue();
@@ -467,7 +457,7 @@ public class FailoverClientConnectionFactoryTests {
@SuppressWarnings("unchecked")
@Test
public void testFailoverCachedWithGateway() throws Exception {
public void testFailoverCachedWithGateway() {
final TcpNetServerConnectionFactory server = new TcpNetServerConnectionFactory(0);
server.setBeanName("server");
server.afterPropertiesSet();
@@ -490,7 +480,7 @@ public class FailoverClientConnectionFactoryTests {
cachingClient.afterPropertiesSet();
// Failover
List<AbstractClientConnectionFactory> clientFactories = new ArrayList<AbstractClientConnectionFactory>();
List<AbstractClientConnectionFactory> clientFactories = new ArrayList<>();
clientFactories.add(cachingClient);
FailoverClientConnectionFactory failoverClient = new FailoverClientConnectionFactory(clientFactories);
failoverClient.setSingleUse(true);
@@ -505,13 +495,13 @@ public class FailoverClientConnectionFactoryTests {
outbound.afterPropertiesSet();
outbound.start();
outbound.handleMessage(new GenericMessage<String>("foo"));
outbound.handleMessage(new GenericMessage<>("foo"));
Message<byte[]> result = (Message<byte[]>) replyChannel.receive(10000);
assertThat(result).isNotNull();
assertThat(new String(result.getPayload())).isEqualTo("foo");
// INT-4024 - second reply had bad connection id
outbound.handleMessage(new GenericMessage<String>("foo"));
outbound.handleMessage(new GenericMessage<>("foo"));
result = (Message<byte[]>) replyChannel.receive(10000);
assertThat(result).isNotNull();
assertThat(new String(result.getPayload())).isEqualTo("foo");
@@ -557,13 +547,13 @@ public class FailoverClientConnectionFactoryTests {
cachingFactory2.setBeanName("cache2");
// Failover
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
factories.add(cachingFactory1);
factories.add(cachingFactory2);
FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories);
failoverFactory.start();
TcpConnection conn1 = failoverFactory.getConnection();
GenericMessage<String> message = new GenericMessage<String>("foo");
GenericMessage<String> message = new GenericMessage<>("foo");
conn1.send(message);
conn1.close();
TcpConnection conn2 = failoverFactory.getConnection();
@@ -595,9 +585,11 @@ public class FailoverClientConnectionFactoryTests {
client2.setTaskExecutor(holder.exec);
client1.setBeanName("client1");
client2.setBeanName("client2");
client1.setApplicationEventPublisher(NULL_PUBLISHER);
client2.setApplicationEventPublisher(NULL_PUBLISHER);
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
client1.setApplicationEventPublisher(event -> {
});
client2.setApplicationEventPublisher(event -> {
});
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
factories.add(client1);
factories.add(client2);
FailoverClientConnectionFactory failFactory = new FailoverClientConnectionFactory(factories);
@@ -610,10 +602,10 @@ public class FailoverClientConnectionFactoryTests {
outGateway.start();
QueueChannel replyChannel = new QueueChannel();
outGateway.setReplyChannel(replyChannel);
Message<String> message = new GenericMessage<String>("foo");
Message<String> message = new GenericMessage<>("foo");
outGateway.setRemoteTimeout(120000);
outGateway.handleMessage(message);
Socket socket = null;
Socket socket;
if (!singleUse) {
socket = getSocket(client1);
port1 = socket.getLocalPort();
@@ -644,12 +636,14 @@ public class FailoverClientConnectionFactoryTests {
server2.setTaskExecutor(exec);
server1.setBeanName("server1");
server2.setBeanName("server2");
server1.setApplicationEventPublisher(NULL_PUBLISHER);
server2.setApplicationEventPublisher(NULL_PUBLISHER);
server1.setApplicationEventPublisher(event -> {
});
server2.setApplicationEventPublisher(event -> {
});
TcpInboundGateway gateway1 = new TcpInboundGateway();
gateway1.setConnectionFactory(server1);
SubscribableChannel channel = new DirectChannel();
final AtomicReference<String> connectionId = new AtomicReference<String>();
final AtomicReference<String> connectionId = new AtomicReference<>();
channel.subscribe(message -> {
connectionId.set((String) message.getHeaders().get(IpHeaders.CONNECTION_ID));
((MessageChannel) message.getHeaders().getReplyChannel()).send(message);
@@ -695,7 +689,9 @@ public class FailoverClientConnectionFactoryTests {
}
private static AbstractClientConnectionFactory createFactoryWithMockConnection(TcpConnectionSupport mockConn) throws Exception {
private static AbstractClientConnectionFactory createFactoryWithMockConnection(TcpConnectionSupport mockConn)
throws Exception {
AbstractClientConnectionFactory factory = mock(AbstractClientConnectionFactory.class);
when(factory.getConnection()).thenReturn(mockConn);
when(factory.isActive()).thenReturn(true);

View File

@@ -24,6 +24,7 @@ import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -48,7 +49,6 @@ import org.apache.sshd.sftp.client.impl.AbstractSftpClient;
import org.apache.sshd.sftp.server.SftpSubsystemFactory;
import org.junit.jupiter.api.Test;
import org.springframework.core.task.AsyncTaskExecutor;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import static org.assertj.core.api.Assertions.assertThat;
@@ -133,20 +133,27 @@ public class SftpSessionFactoryTests {
sftpSessionFactory.setPassword("pass");
sftpSessionFactory.setAllowUnknownKeys(true);
List<SftpSession> concurrentSessions = new ArrayList<>();
List<CompletableFuture<SftpSession>> concurrentSessions = new ArrayList<>();
AsyncTaskExecutor asyncTaskExecutor = new SimpleAsyncTaskExecutor();
for (int i = 0; i < 3; i++) {
asyncTaskExecutor.execute(() -> concurrentSessions.add(sftpSessionFactory.getSession()));
try (var asyncTaskExecutor = new SimpleAsyncTaskExecutor()) {
for (int i = 0; i < 3; i++) {
concurrentSessions.add(asyncTaskExecutor.submitCompletable(sftpSessionFactory::getSession));
}
}
await().atMost(Duration.ofSeconds(30)).until(() -> concurrentSessions.size() == 3);
assertThat(CompletableFuture.allOf(concurrentSessions.toArray(CompletableFuture[]::new)))
.succeedsWithin(Duration.ofSeconds(10));
assertThat(concurrentSessions.get(0))
.isNotEqualTo(concurrentSessions.get(1))
.isNotEqualTo(concurrentSessions.get(2));
List<SftpSession> sftpSessions = concurrentSessions
.stream()
.map(CompletableFuture::join)
.toList();
assertThat(concurrentSessions.get(1)).isNotEqualTo(concurrentSessions.get(2));
assertThat(sftpSessions.get(0))
.isNotEqualTo(sftpSessions.get(1))
.isNotEqualTo(sftpSessions.get(2));
assertThat(sftpSessions.get(1)).isNotEqualTo(sftpSessions.get(2));
sftpSessionFactory.destroy();
}