diff --git a/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/FailoverClientConnectionFactory.java b/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/FailoverClientConnectionFactory.java index 3bb492fa02..d924e78ffa 100644 --- a/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/FailoverClientConnectionFactory.java +++ b/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/FailoverClientConnectionFactory.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,11 +35,12 @@ import org.springframework.util.Assert; /** * Given a list of connection factories, serves up {@link TcpConnection}s - * that can iterate over a connection from each factory until the write + * that can iterate over a connection from each factory until the {@code write} * succeeds or the list is exhausted. * * @author Gary Russell * @author Christian Tzolov + * @author Artem Bilan * * @since 2.2 * @@ -163,8 +164,9 @@ public class FailoverClientConnectionFactory extends AbstractClientConnectionFac return sharedConnection; } FailoverTcpConnection failoverTcpConnection = new FailoverTcpConnection(this.factories); - if (getListener() != null) { - failoverTcpConnection.registerListener(getListener()); + TcpListener listener = getListener(); + if (listener != null) { + failoverTcpConnection.registerListener(listener); } failoverTcpConnection.incrementEpoch(); if (shared) { @@ -286,9 +288,7 @@ public class FailoverClientConnectionFactory extends AbstractClientConnectionFac } catch (RuntimeException e) { if (logger.isDebugEnabled()) { - logger.debug(nextFactory + " failed with " - + e.toString() - + ", trying another"); + logger.debug(nextFactory + " failed with " + e + ", trying another"); } if (restartedList && (lastFactoryToTry == null || lastFactoryToTry.equals(nextFactory))) { logger.debug("Failover failed to find a connection"); diff --git a/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/connection/FailoverClientConnectionFactoryTests.java b/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/connection/FailoverClientConnectionFactoryTests.java index f363cf81c9..2733937d29 100644 --- a/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/connection/FailoverClientConnectionFactoryTests.java +++ b/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/connection/FailoverClientConnectionFactoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import java.net.Socket; import java.nio.channels.SocketChannel; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; @@ -39,8 +40,6 @@ import org.mockito.InOrder; import org.mockito.Mockito; import org.springframework.beans.factory.BeanFactory; -import org.springframework.context.ApplicationEvent; -import org.springframework.context.ApplicationEventPublisher; import org.springframework.core.task.SimpleAsyncTaskExecutor; import org.springframework.integration.channel.DirectChannel; import org.springframework.integration.channel.QueueChannel; @@ -79,26 +78,13 @@ import static org.mockito.Mockito.when; */ public class FailoverClientConnectionFactoryTests { - private static final ApplicationEventPublisher NULL_PUBLISHER = new ApplicationEventPublisher() { - - @Override - public void publishEvent(ApplicationEvent event) { - } - - @Override - public void publishEvent(Object event) { - - } - - }; - @Test public void testFailoverGood() throws Exception { TcpConnectionSupport conn1 = makeMockConnection(); TcpConnectionSupport conn2 = makeMockConnection(); AbstractClientConnectionFactory factory1 = createFactoryWithMockConnection(conn1); AbstractClientConnectionFactory factory2 = createFactoryWithMockConnection(conn2); - List factories = new ArrayList(); + List factories = new ArrayList<>(); factories.add(factory1); factories.add(factory2); doThrow(new UncheckedIOException(new IOException("fail"))) @@ -106,7 +92,7 @@ public class FailoverClientConnectionFactoryTests { doAnswer(invocation -> null).when(conn2).send(Mockito.any(Message.class)); FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories); failoverFactory.start(); - GenericMessage message = new GenericMessage("foo"); + GenericMessage message = new GenericMessage<>("foo"); failoverFactory.getConnection().send(message); Mockito.verify(conn2).send(message); } @@ -129,7 +115,7 @@ public class FailoverClientConnectionFactoryTests { private void testRefreshShared(boolean closeOnRefresh, long interval) throws Exception { AbstractClientConnectionFactory factory1 = mock(AbstractClientConnectionFactory.class); AbstractClientConnectionFactory factory2 = mock(AbstractClientConnectionFactory.class); - List factories = new ArrayList(); + List factories = new ArrayList<>(); factories.add(factory1); factories.add(factory2); TcpConnectionSupport conn1 = makeMockConnection(); @@ -182,7 +168,7 @@ public class FailoverClientConnectionFactoryTests { TcpConnectionSupport conn2 = makeMockConnection(); AbstractClientConnectionFactory factory1 = createFactoryWithMockConnection(conn1); AbstractClientConnectionFactory factory2 = createFactoryWithMockConnection(conn2); - List factories = new ArrayList(); + List factories = new ArrayList<>(); factories.add(factory1); factories.add(factory2); doThrow(new UncheckedIOException(new IOException("fail"))) @@ -191,7 +177,7 @@ public class FailoverClientConnectionFactoryTests { .when(conn2).send(Mockito.any(Message.class)); FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories); failoverFactory.start(); - GenericMessage message = new GenericMessage("foo"); + GenericMessage message = new GenericMessage<>("foo"); assertThatExceptionOfType(UncheckedIOException.class).isThrownBy(() -> failoverFactory.getConnection().send(message)); Mockito.verify(conn2).send(message); @@ -214,7 +200,7 @@ public class FailoverClientConnectionFactoryTests { TcpNetClientConnectionFactory cf1 = new TcpNetClientConnectionFactory("localhost", ss1.getLocalPort()); AbstractClientConnectionFactory cf2 = mock(AbstractClientConnectionFactory.class); doThrow(new UncheckedIOException(new IOException("fail"))).when(cf2).getConnection(); - CountDownLatch latch = new CountDownLatch(2); + CountDownLatch latch = new CountDownLatch(1); cf1.setApplicationEventPublisher(event -> { if (event instanceof TcpConnectionCloseEvent) { latch.countDown(); @@ -223,12 +209,16 @@ public class FailoverClientConnectionFactoryTests { cf2.setApplicationEventPublisher(event -> { }); FailoverClientConnectionFactory fccf = new FailoverClientConnectionFactory(List.of(cf1, cf2)); - fccf.registerListener(msf -> { - latch.countDown(); - return false; - }); + + CompletableFuture> messageCompletableFuture = new CompletableFuture<>(); + fccf.registerListener(messageCompletableFuture::complete); + fccf.start(); fccf.getConnection().send(new GenericMessage<>("test")); + assertThat(messageCompletableFuture) + .succeedsWithin(10, TimeUnit.SECONDS) + .extracting(Message::getPayload) + .isEqualTo("ok".getBytes()); assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue(); assertThatExceptionOfType(UncheckedIOException.class).isThrownBy(() -> fccf.getConnection().send(new GenericMessage<>("test"))); @@ -240,7 +230,7 @@ public class FailoverClientConnectionFactoryTests { TcpConnectionSupport conn2 = makeMockConnection(); AbstractClientConnectionFactory factory1 = createFactoryWithMockConnection(conn1); AbstractClientConnectionFactory factory2 = createFactoryWithMockConnection(conn2); - List factories = new ArrayList(); + List factories = new ArrayList<>(); factories.add(factory1); factories.add(factory2); final AtomicBoolean failedOnce = new AtomicBoolean(); @@ -255,7 +245,7 @@ public class FailoverClientConnectionFactoryTests { .when(conn2).send(Mockito.any(Message.class)); FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories); failoverFactory.start(); - GenericMessage message = new GenericMessage("foo"); + GenericMessage message = new GenericMessage<>("foo"); failoverFactory.getConnection().send(message); Mockito.verify(conn2).send(message); Mockito.verify(conn1, times(2)).send(message); @@ -265,7 +255,7 @@ public class FailoverClientConnectionFactoryTests { public void testFailoverConnectNone() throws Exception { AbstractClientConnectionFactory factory1 = mock(AbstractClientConnectionFactory.class); AbstractClientConnectionFactory factory2 = mock(AbstractClientConnectionFactory.class); - List factories = new ArrayList(); + List factories = new ArrayList<>(); factories.add(factory1); factories.add(factory2); when(factory1.getConnection()).thenThrow(new UncheckedIOException(new IOException("fail"))); @@ -274,7 +264,7 @@ public class FailoverClientConnectionFactoryTests { when(factory2.isActive()).thenReturn(true); FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories); failoverFactory.start(); - GenericMessage message = new GenericMessage("foo"); + GenericMessage message = new GenericMessage<>("foo"); assertThatExceptionOfType(UncheckedIOException.class).isThrownBy(() -> failoverFactory.getConnection().send(message)); } @@ -283,7 +273,7 @@ public class FailoverClientConnectionFactoryTests { public void testFailoverConnectToFirstAfterTriedAll() throws Exception { AbstractClientConnectionFactory factory1 = mock(AbstractClientConnectionFactory.class); AbstractClientConnectionFactory factory2 = mock(AbstractClientConnectionFactory.class); - List factories = new ArrayList(); + List factories = new ArrayList<>(); factories.add(factory1); factories.add(factory2); TcpConnectionSupport conn1 = makeMockConnection(); @@ -308,7 +298,7 @@ public class FailoverClientConnectionFactoryTests { TcpConnectionSupport conn2 = makeMockConnection(); AbstractClientConnectionFactory factory1 = createFactoryWithMockConnection(conn1); AbstractClientConnectionFactory factory2 = createFactoryWithMockConnection(conn2); - List factories = new ArrayList(); + List factories = new ArrayList<>(); factories.add(factory1); factories.add(factory2); final AtomicInteger failCount = new AtomicInteger(); @@ -322,7 +312,7 @@ public class FailoverClientConnectionFactoryTests { .when(conn2).send(Mockito.any(Message.class)); FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories); failoverFactory.start(); - GenericMessage message = new GenericMessage("foo"); + GenericMessage message = new GenericMessage<>("foo"); assertThatExceptionOfType(UncheckedIOException.class) .isThrownBy(() -> failoverFactory.getConnection().send(message)); failoverFactory.getConnection().send(message); @@ -426,27 +416,27 @@ public class FailoverClientConnectionFactoryTests { cachingFactory2.setBeanName("cache2"); // Failover - List factories = new ArrayList(); + List factories = new ArrayList<>(); factories.add(cachingFactory1); factories.add(cachingFactory2); FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories); failoverFactory.start(); TcpConnection conn1 = failoverFactory.getConnection(); - conn1.send(new GenericMessage("foo1")); + conn1.send(new GenericMessage<>("foo1")); conn1.close(); TcpConnection conn2 = failoverFactory.getConnection(); assertThat((TestUtils.getPropertyValue(conn2, "delegate", TcpConnectionInterceptorSupport.class)) .getTheConnection()) .isSameAs((TestUtils.getPropertyValue(conn1, "delegate", TcpConnectionInterceptorSupport.class)) .getTheConnection()); - conn2.send(new GenericMessage("foo2")); + conn2.send(new GenericMessage<>("foo2")); conn1 = failoverFactory.getConnection(); assertThat((TestUtils.getPropertyValue(conn2, "delegate", TcpConnectionInterceptorSupport.class)) .getTheConnection()) .isNotSameAs((TestUtils.getPropertyValue(conn1, "delegate", TcpConnectionInterceptorSupport.class)) .getTheConnection()); - conn1.send(new GenericMessage("foo3")); + conn1.send(new GenericMessage<>("foo3")); conn1.close(); conn2.close(); assertThat(latch1.await(10, TimeUnit.SECONDS)).isTrue(); @@ -455,8 +445,8 @@ public class FailoverClientConnectionFactoryTests { TestingUtilities.waitUntilFactoryHasThisNumberOfConnections(factory1, 0); conn1 = failoverFactory.getConnection(); conn2 = failoverFactory.getConnection(); - conn1.send(new GenericMessage("foo4")); - conn2.send(new GenericMessage("foo5")); + conn1.send(new GenericMessage<>("foo4")); + conn2.send(new GenericMessage<>("foo5")); conn1.close(); conn2.close(); assertThat(latch2.await(10, TimeUnit.SECONDS)).isTrue(); @@ -467,7 +457,7 @@ public class FailoverClientConnectionFactoryTests { @SuppressWarnings("unchecked") @Test - public void testFailoverCachedWithGateway() throws Exception { + public void testFailoverCachedWithGateway() { final TcpNetServerConnectionFactory server = new TcpNetServerConnectionFactory(0); server.setBeanName("server"); server.afterPropertiesSet(); @@ -490,7 +480,7 @@ public class FailoverClientConnectionFactoryTests { cachingClient.afterPropertiesSet(); // Failover - List clientFactories = new ArrayList(); + List clientFactories = new ArrayList<>(); clientFactories.add(cachingClient); FailoverClientConnectionFactory failoverClient = new FailoverClientConnectionFactory(clientFactories); failoverClient.setSingleUse(true); @@ -505,13 +495,13 @@ public class FailoverClientConnectionFactoryTests { outbound.afterPropertiesSet(); outbound.start(); - outbound.handleMessage(new GenericMessage("foo")); + outbound.handleMessage(new GenericMessage<>("foo")); Message result = (Message) replyChannel.receive(10000); assertThat(result).isNotNull(); assertThat(new String(result.getPayload())).isEqualTo("foo"); // INT-4024 - second reply had bad connection id - outbound.handleMessage(new GenericMessage("foo")); + outbound.handleMessage(new GenericMessage<>("foo")); result = (Message) replyChannel.receive(10000); assertThat(result).isNotNull(); assertThat(new String(result.getPayload())).isEqualTo("foo"); @@ -557,13 +547,13 @@ public class FailoverClientConnectionFactoryTests { cachingFactory2.setBeanName("cache2"); // Failover - List factories = new ArrayList(); + List factories = new ArrayList<>(); factories.add(cachingFactory1); factories.add(cachingFactory2); FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories); failoverFactory.start(); TcpConnection conn1 = failoverFactory.getConnection(); - GenericMessage message = new GenericMessage("foo"); + GenericMessage message = new GenericMessage<>("foo"); conn1.send(message); conn1.close(); TcpConnection conn2 = failoverFactory.getConnection(); @@ -595,9 +585,11 @@ public class FailoverClientConnectionFactoryTests { client2.setTaskExecutor(holder.exec); client1.setBeanName("client1"); client2.setBeanName("client2"); - client1.setApplicationEventPublisher(NULL_PUBLISHER); - client2.setApplicationEventPublisher(NULL_PUBLISHER); - List factories = new ArrayList(); + client1.setApplicationEventPublisher(event -> { + }); + client2.setApplicationEventPublisher(event -> { + }); + List factories = new ArrayList<>(); factories.add(client1); factories.add(client2); FailoverClientConnectionFactory failFactory = new FailoverClientConnectionFactory(factories); @@ -610,10 +602,10 @@ public class FailoverClientConnectionFactoryTests { outGateway.start(); QueueChannel replyChannel = new QueueChannel(); outGateway.setReplyChannel(replyChannel); - Message message = new GenericMessage("foo"); + Message message = new GenericMessage<>("foo"); outGateway.setRemoteTimeout(120000); outGateway.handleMessage(message); - Socket socket = null; + Socket socket; if (!singleUse) { socket = getSocket(client1); port1 = socket.getLocalPort(); @@ -644,12 +636,14 @@ public class FailoverClientConnectionFactoryTests { server2.setTaskExecutor(exec); server1.setBeanName("server1"); server2.setBeanName("server2"); - server1.setApplicationEventPublisher(NULL_PUBLISHER); - server2.setApplicationEventPublisher(NULL_PUBLISHER); + server1.setApplicationEventPublisher(event -> { + }); + server2.setApplicationEventPublisher(event -> { + }); TcpInboundGateway gateway1 = new TcpInboundGateway(); gateway1.setConnectionFactory(server1); SubscribableChannel channel = new DirectChannel(); - final AtomicReference connectionId = new AtomicReference(); + final AtomicReference connectionId = new AtomicReference<>(); channel.subscribe(message -> { connectionId.set((String) message.getHeaders().get(IpHeaders.CONNECTION_ID)); ((MessageChannel) message.getHeaders().getReplyChannel()).send(message); @@ -695,7 +689,9 @@ public class FailoverClientConnectionFactoryTests { } - private static AbstractClientConnectionFactory createFactoryWithMockConnection(TcpConnectionSupport mockConn) throws Exception { + private static AbstractClientConnectionFactory createFactoryWithMockConnection(TcpConnectionSupport mockConn) + throws Exception { + AbstractClientConnectionFactory factory = mock(AbstractClientConnectionFactory.class); when(factory.getConnection()).thenReturn(mockConn); when(factory.isActive()).thenReturn(true); diff --git a/spring-integration-sftp/src/test/java/org/springframework/integration/sftp/session/SftpSessionFactoryTests.java b/spring-integration-sftp/src/test/java/org/springframework/integration/sftp/session/SftpSessionFactoryTests.java index 7c4111bfb0..c7e1b0c2cd 100644 --- a/spring-integration-sftp/src/test/java/org/springframework/integration/sftp/session/SftpSessionFactoryTests.java +++ b/spring-integration-sftp/src/test/java/org/springframework/integration/sftp/session/SftpSessionFactoryTests.java @@ -24,6 +24,7 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -48,7 +49,6 @@ import org.apache.sshd.sftp.client.impl.AbstractSftpClient; import org.apache.sshd.sftp.server.SftpSubsystemFactory; import org.junit.jupiter.api.Test; -import org.springframework.core.task.AsyncTaskExecutor; import org.springframework.core.task.SimpleAsyncTaskExecutor; import static org.assertj.core.api.Assertions.assertThat; @@ -133,20 +133,27 @@ public class SftpSessionFactoryTests { sftpSessionFactory.setPassword("pass"); sftpSessionFactory.setAllowUnknownKeys(true); - List concurrentSessions = new ArrayList<>(); + List> concurrentSessions = new ArrayList<>(); - AsyncTaskExecutor asyncTaskExecutor = new SimpleAsyncTaskExecutor(); - for (int i = 0; i < 3; i++) { - asyncTaskExecutor.execute(() -> concurrentSessions.add(sftpSessionFactory.getSession())); + try (var asyncTaskExecutor = new SimpleAsyncTaskExecutor()) { + for (int i = 0; i < 3; i++) { + concurrentSessions.add(asyncTaskExecutor.submitCompletable(sftpSessionFactory::getSession)); + } } - await().atMost(Duration.ofSeconds(30)).until(() -> concurrentSessions.size() == 3); + assertThat(CompletableFuture.allOf(concurrentSessions.toArray(CompletableFuture[]::new))) + .succeedsWithin(Duration.ofSeconds(10)); - assertThat(concurrentSessions.get(0)) - .isNotEqualTo(concurrentSessions.get(1)) - .isNotEqualTo(concurrentSessions.get(2)); + List sftpSessions = concurrentSessions + .stream() + .map(CompletableFuture::join) + .toList(); - assertThat(concurrentSessions.get(1)).isNotEqualTo(concurrentSessions.get(2)); + assertThat(sftpSessions.get(0)) + .isNotEqualTo(sftpSessions.get(1)) + .isNotEqualTo(sftpSessions.get(2)); + + assertThat(sftpSessions.get(1)).isNotEqualTo(sftpSessions.get(2)); sftpSessionFactory.destroy(); }