diff --git a/spring-cloud-stream-integration-tests/src/test/java/org/springframework/cloud/stream/config/StreamListenerHandlerMethodTests.java b/spring-cloud-stream-integration-tests/src/test/java/org/springframework/cloud/stream/config/StreamListenerHandlerMethodTests.java index 5bb7a2647..61f09a9ab 100644 --- a/spring-cloud-stream-integration-tests/src/test/java/org/springframework/cloud/stream/config/StreamListenerHandlerMethodTests.java +++ b/spring-cloud-stream-integration-tests/src/test/java/org/springframework/cloud/stream/config/StreamListenerHandlerMethodTests.java @@ -30,10 +30,14 @@ import org.springframework.cloud.stream.annotation.EnableBinding; import org.springframework.cloud.stream.annotation.Input; import org.springframework.cloud.stream.annotation.Output; import org.springframework.cloud.stream.annotation.StreamListener; +import org.springframework.cloud.stream.binding.StreamListenerErrorMessages; import org.springframework.cloud.stream.messaging.Processor; import org.springframework.cloud.stream.messaging.Sink; import org.springframework.cloud.stream.test.binder.MessageCollector; import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.integration.annotation.Router; +import org.springframework.integration.channel.DirectChannel; +import org.springframework.integration.support.DefaultMessageBuilderFactory; import org.springframework.integration.support.MessageBuilder; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; @@ -84,7 +88,22 @@ public class StreamListenerHandlerMethodTests { MessageCollector messageCollector = context.getBean(MessageCollector.class); Message result = messageCollector.forChannel(processor.output()).poll(1000, TimeUnit.MILLISECONDS); assertThat(result).isNotNull(); - assertThat(result.getPayload()).isEqualTo(result.getPayload().toString().toUpperCase()); + assertThat(result.getPayload()).isEqualTo(testMessage.toUpperCase()); + context.close(); + } + + @Test + public void testStreamListenerMethodWithTargetBeanFromOutside() throws Exception { + ConfigurableApplicationContext context = SpringApplication.run(TestStreamListenerMethodWithTargetBeanFromOutside.class, "--server.port=0"); + Sink sink = context.getBean(Sink.class); + final String testMessageToSend = "testing"; + sink.input().send(MessageBuilder.withPayload(testMessageToSend).build()); + DirectChannel directChannel = (DirectChannel) context.getBean(testMessageToSend.toUpperCase(), MessageChannel.class); + MessageCollector messageCollector = context.getBean(MessageCollector.class); + Message result = messageCollector.forChannel(directChannel).poll(1000, TimeUnit.MILLISECONDS); + sink.input().send(MessageBuilder.withPayload(testMessageToSend).build()); + assertThat(result).isNotNull(); + assertThat(result.getPayload()).isEqualTo(testMessageToSend.toUpperCase()); context.close(); } @@ -139,8 +158,8 @@ public class StreamListenerHandlerMethodTests { fail("Exception expected on using invalid inbound name"); } catch (BeanCreationException e) { - assertThat(e.getCause()).isInstanceOf(NoSuchBeanDefinitionException.class); - assertThat(e.getCause()).hasMessageContaining("'invalid'"); + assertThat(e.getCause()).isInstanceOf(IllegalArgumentException.class); + assertThat(e.getCause()).hasMessageContaining(StreamListenerErrorMessages.INVALID_DECLARATIVE_METHOD_PARAMETERS); } } @@ -303,6 +322,24 @@ public class StreamListenerHandlerMethodTests { } } + @EnableBinding(Sink.class) + @EnableAutoConfiguration + public static class TestStreamListenerMethodWithTargetBeanFromOutside { + + private static final String ROUTER_QUEUE = "routeInstruction"; + + @StreamListener(Sink.INPUT) + @SendTo(ROUTER_QUEUE) + public Message convertMessageBody(Message message) { + return new DefaultMessageBuilderFactory().withPayload(message.getPayload().toUpperCase()).build(); + } + + @Router(inputChannel = ROUTER_QUEUE) + public String route(String message) { + return message.toUpperCase(); + } + } + @EnableBinding({Sink.class}) @EnableAutoConfiguration public static class TestInvalidInputOnMethod { diff --git a/spring-cloud-stream-integration-tests/src/test/java/org/springframework/cloud/stream/config/StreamListenerWithAnnotatedInputOutputArgsTests.java b/spring-cloud-stream-integration-tests/src/test/java/org/springframework/cloud/stream/config/StreamListenerWithAnnotatedInputOutputArgsTests.java index 974678d26..1eb0e2fea 100644 --- a/spring-cloud-stream-integration-tests/src/test/java/org/springframework/cloud/stream/config/StreamListenerWithAnnotatedInputOutputArgsTests.java +++ b/spring-cloud-stream-integration-tests/src/test/java/org/springframework/cloud/stream/config/StreamListenerWithAnnotatedInputOutputArgsTests.java @@ -20,13 +20,14 @@ import java.util.concurrent.TimeUnit; import org.junit.Test; -import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.beans.factory.BeanCreationException; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.cloud.stream.annotation.EnableBinding; import org.springframework.cloud.stream.annotation.Input; import org.springframework.cloud.stream.annotation.Output; import org.springframework.cloud.stream.annotation.StreamListener; +import org.springframework.cloud.stream.binding.StreamListenerErrorMessages; import org.springframework.cloud.stream.messaging.Processor; import org.springframework.cloud.stream.test.binder.MessageCollector; import org.springframework.context.ConfigurableApplicationContext; @@ -57,9 +58,9 @@ public class StreamListenerWithAnnotatedInputOutputArgsTests { public void testInputOutputArgsWithMoreParameters() { try { SpringApplication.run(TestInputOutputArgsWithMoreParameters.class, "--server.port=0"); - fail("Expected exception: "+ INVALID_DECLARATIVE_METHOD_PARAMETERS); + fail("Expected exception: " + INVALID_DECLARATIVE_METHOD_PARAMETERS); } - catch (Exception e) { + catch (BeanCreationException e) { assertThat(e.getMessage()).contains(INVALID_DECLARATIVE_METHOD_PARAMETERS); } } @@ -70,9 +71,9 @@ public class StreamListenerWithAnnotatedInputOutputArgsTests { SpringApplication.run(TestInputOutputArgsWithInvalidBindableTarget.class, "--server.port=0"); fail("Exception expected on using invalid bindable target as method parameter"); } - catch (Exception e) { - assertThat(e.getCause()).isInstanceOf(NoSuchBeanDefinitionException.class); - assertThat(e.getCause()).hasMessageContaining("'invalid'"); + catch (BeanCreationException e) { + assertThat(e.getCause()).isInstanceOf(IllegalArgumentException.class); + assertThat(e.getCause()).hasMessageContaining(StreamListenerErrorMessages.INVALID_DECLARATIVE_METHOD_PARAMETERS); } } diff --git a/spring-cloud-stream/src/main/java/org/springframework/cloud/stream/binding/StreamListenerAnnotationBeanPostProcessor.java b/spring-cloud-stream/src/main/java/org/springframework/cloud/stream/binding/StreamListenerAnnotationBeanPostProcessor.java index a21be23b0..e54543a4c 100644 --- a/spring-cloud-stream/src/main/java/org/springframework/cloud/stream/binding/StreamListenerAnnotationBeanPostProcessor.java +++ b/spring-cloud-stream/src/main/java/org/springframework/cloud/stream/binding/StreamListenerAnnotationBeanPostProcessor.java @@ -30,6 +30,7 @@ import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.beans.factory.BeanFactoryUtils; import org.springframework.beans.factory.BeanInitializationException; import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.SmartInitializingSingleton; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.config.BeanExpressionContext; @@ -199,7 +200,7 @@ public class StreamListenerAnnotationBeanPostProcessor .getValue(methodParameter.getParameterAnnotation(Input.class)); Assert.isTrue(StringUtils.hasText(inboundName), StreamListenerErrorMessages.INVALID_INBOUND_NAME); Assert.isTrue( - isDeclarativeMethodParameter(this.applicationContext.getBean(inboundName), methodParameter), + isDeclarativeMethodParameter(inboundName, methodParameter), StreamListenerErrorMessages.INVALID_DECLARATIVE_METHOD_PARAMETERS); return true; } @@ -208,33 +209,44 @@ public class StreamListenerAnnotationBeanPostProcessor .getValue(methodParameter.getParameterAnnotation(Output.class)); Assert.isTrue(StringUtils.hasText(outboundName), StreamListenerErrorMessages.INVALID_OUTBOUND_NAME); Assert.isTrue( - isDeclarativeMethodParameter(this.applicationContext.getBean(outboundName), methodParameter), + isDeclarativeMethodParameter(outboundName, methodParameter), StreamListenerErrorMessages.INVALID_DECLARATIVE_METHOD_PARAMETERS); return true; } if (StringUtils.hasText(methodAnnotatedOutboundName)) { - return isDeclarativeMethodParameter(this.applicationContext.getBean(methodAnnotatedOutboundName), - methodParameter); + return isDeclarativeMethodParameter(methodAnnotatedOutboundName, methodParameter); } if (StringUtils.hasText(methodAnnotatedInboundName)) { - return isDeclarativeMethodParameter(this.applicationContext.getBean(methodAnnotatedInboundName), - methodParameter); + return isDeclarativeMethodParameter(methodAnnotatedInboundName, methodParameter); } } return false; } - private boolean isDeclarativeMethodParameter(Object targetBean, MethodParameter methodParameter) { - if (targetBean != null) { + private boolean isDeclarativeMethodParameter(String targetBeanName, MethodParameter methodParameter) { + try { + Class targetBeanClass = this.applicationContext.getType(targetBeanName); if (!methodParameter.getParameterType().equals(Object.class) - && methodParameter.getParameterType().isAssignableFrom(targetBean.getClass())) { + && (targetBeanClass.isAssignableFrom(methodParameter.getParameterType()) || + methodParameter.getParameterType().isAssignableFrom(targetBeanClass))) { return true; } - for (StreamListenerParameterAdapter streamListenerParameterAdapter : this.streamListenerParameterAdapters) { - if (streamListenerParameterAdapter.supports(targetBean.getClass(), methodParameter)) { - return true; + } + catch (NoSuchBeanDefinitionException e) { + // ignore as the bean definition might not exist yet. + } + if (!this.streamListenerParameterAdapters.isEmpty()) { + try { + Object targetBean = this.applicationContext.getBean(targetBeanName); + for (StreamListenerParameterAdapter streamListenerParameterAdapter : this.streamListenerParameterAdapters) { + if (streamListenerParameterAdapter.supports(targetBean.getClass(), methodParameter)) { + return true; + } } } + catch (BeansException e) { + // ignore as the bean definition might not exist yet. + } } return false; }