diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/core/MessageHeaders.java b/org.springframework.integration/src/main/java/org/springframework/integration/core/MessageHeaders.java index e2ac8c19cf..12701a81ba 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/core/MessageHeaders.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/core/MessageHeaders.java @@ -16,14 +16,22 @@ package org.springframework.integration.core; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.io.Serializable; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.UUID; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + /** * The headers for a {@link Message}. * @@ -32,6 +40,8 @@ import java.util.UUID; */ public final class MessageHeaders implements Map, Serializable { + private static final Log logger = LogFactory.getLog(MessageHeaders.class); + private static final String PREFIX = "spring.integration."; public static final String ID = PREFIX + "id"; @@ -110,10 +120,8 @@ public final class MessageHeaders implements Map, Serializable { return null; } if (!type.isAssignableFrom(value.getClass())) { - throw new IllegalArgumentException( - "Incorrect type specified for header '" + key - + "'. Expected [" + type + "] but actual type is [" - + value.getClass() + "]"); + throw new IllegalArgumentException("Incorrect type specified for header '" + key + + "'. Expected [" + type + "] but actual type is [" + value.getClass() + "]"); } return (T) value; } @@ -193,4 +201,28 @@ public final class MessageHeaders implements Map, Serializable { throw new UnsupportedOperationException("MessageHeaders is immutable."); } + /* + * Serialization methods + */ + + private void writeObject(ObjectOutputStream out) throws IOException { + List keysToRemove = new ArrayList(); + for (Map.Entry entry : this.headers.entrySet()) { + if (!(entry.getValue() instanceof Serializable)) { + keysToRemove.add(entry.getKey()); + } + } + for (String key : keysToRemove) { + if (logger.isWarnEnabled()) { + logger.warn("removing non-serializable header: " + key); + } + this.headers.remove(key); + } + out.defaultWriteObject(); + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + } + } diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/message/MessageBuilder.java b/org.springframework.integration/src/main/java/org/springframework/integration/message/MessageBuilder.java index 3cafeef44f..e94d987962 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/message/MessageBuilder.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/message/MessageBuilder.java @@ -51,7 +51,7 @@ public final class MessageBuilder { this.payload = payload; this.originalMessage = originalMessage; if (originalMessage != null) { - this.headers.putAll(originalMessage.getHeaders()); + this.copyHeaders(originalMessage.getHeaders()); } } diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/message/MessageHeadersTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/message/MessageHeadersTests.java index a1cb1dbd26..a474f1b726 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/message/MessageHeadersTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/message/MessageHeadersTests.java @@ -21,6 +21,10 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.util.HashMap; import java.util.Map; import java.util.Set; @@ -92,4 +96,40 @@ public class MessageHeadersTests { assertTrue(keys.contains("key2")); } + @Test + public void serializeWithAllSerializableHeaders() throws Exception { + Map map = new HashMap(); + map.put("name", "joe"); + map.put("age", 42); + MessageHeaders input = new MessageHeaders(map); + MessageHeaders output = (MessageHeaders) serializeAndDeserialize(input); + assertEquals("joe", output.get("name")); + assertEquals(42, output.get("age")); + } + + @Test + public void serializeWithNonSerializableHeader() throws Exception { + Object address = new Object(); + Map map = new HashMap(); + map.put("name", "joe"); + map.put("address", address); + MessageHeaders input = new MessageHeaders(map); + MessageHeaders output = (MessageHeaders) serializeAndDeserialize(input); + assertEquals("joe", output.get("name")); + assertNull(output.get("address")); + } + + + private static Object serializeAndDeserialize(Object object) throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream out = new ObjectOutputStream(baos); + out.writeObject(object); + out.close(); + ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); + ObjectInputStream in = new ObjectInputStream(bais); + Object result = in.readObject(); + in.close(); + return result; + } + }