diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/gateway/GatewayProxyFactoryBean.java b/org.springframework.integration/src/main/java/org/springframework/integration/gateway/GatewayProxyFactoryBean.java index 0531a926c5..1af43694a7 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/gateway/GatewayProxyFactoryBean.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/gateway/GatewayProxyFactoryBean.java @@ -17,7 +17,9 @@ package org.springframework.integration.gateway; import java.lang.reflect.Method; +import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import org.aopalliance.intercept.MethodInterceptor; @@ -167,7 +169,12 @@ public class GatewayProxyFactoryBean extends AbstractEndpoint implements Factory return "gateway proxy for service interface [" + this.serviceInterface + "]"; } if (method.getDeclaringClass().equals(this.serviceInterface)) { - return this.invokeGatewayMethod(invocation); + try { + return this.invokeGatewayMethod(invocation); + } + catch (Exception e) { + rethrowExceptionInThrowsClauseIfPossible(e, invocation.getMethod()); + } } return invocation.proceed(); } @@ -204,6 +211,18 @@ public class GatewayProxyFactoryBean extends AbstractEndpoint implements Factory return (response != null) ? this.typeConverter.convertIfNecessary(response, returnType) : null; } + private void rethrowExceptionInThrowsClauseIfPossible(Throwable originalException, Method method) throws Throwable { + List> exceptionTypes = Arrays.asList(method.getExceptionTypes()); + Throwable t = originalException; + while (t != null) { + if (exceptionTypes.contains(t.getClass())) { + throw t; + } + t = t.getCause(); + } + throw originalException; + } + private MessagingGateway createGatewayForMethod(Method method) throws Exception { SimpleMessagingGateway gateway = new SimpleMessagingGateway( new MethodParameterMessageMapper(method), new SimpleMessageMapper()); diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/gateway/GatewayProxyFactoryBeanTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/gateway/GatewayProxyFactoryBeanTests.java index e60d883dae..7637cb15ae 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/gateway/GatewayProxyFactoryBeanTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/gateway/GatewayProxyFactoryBeanTests.java @@ -19,6 +19,7 @@ package org.springframework.integration.gateway; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import java.lang.reflect.Method; import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; @@ -33,7 +34,10 @@ import org.springframework.integration.channel.PollableChannel; import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.core.Message; import org.springframework.integration.core.MessageChannel; +import org.springframework.integration.endpoint.EventDrivenConsumer; +import org.springframework.integration.message.MessageHandler; import org.springframework.integration.message.StringMessage; +import org.springframework.util.ReflectionUtils; /** * @author Mark Fisher @@ -215,6 +219,25 @@ public class GatewayProxyFactoryBeanTests { assertEquals(expected, proxy.toString().substring(0, expected.length())); } + @Test(expected = TestException.class) + public void testCheckedExceptionRethrownAsIs() throws Exception { + GatewayProxyFactoryBean proxyFactory = new GatewayProxyFactoryBean(); + DirectChannel channel = new DirectChannel(); + EventDrivenConsumer consumer = new EventDrivenConsumer(channel, new MessageHandler() { + public void handleMessage(Message message) { + Method method = ReflectionUtils.findMethod( + GatewayProxyFactoryBeanTests.class, "throwTestException"); + ReflectionUtils.invokeMethod(method, null); + } + }); + consumer.start(); + proxyFactory.setDefaultRequestChannel(channel); + proxyFactory.setServiceInterface(TestExceptionThrowingInterface.class); + proxyFactory.afterPropertiesSet(); + TestExceptionThrowingInterface proxy = (TestExceptionThrowingInterface) proxyFactory.getObject(); + proxy.throwCheckedException("test"); + } + private static void startResponder(final PollableChannel requestChannel) { new Thread(new Runnable() { @@ -226,4 +249,20 @@ public class GatewayProxyFactoryBeanTests { }).start(); } + + public static void throwTestException() throws TestException { + throw new TestException(); + } + + + static interface TestExceptionThrowingInterface { + + String throwCheckedException(String s) throws TestException; + } + + + @SuppressWarnings("serial") + static class TestException extends Exception { + } + }