diff --git a/spring-core/src/main/java/org/springframework/core/serializer/DefaultDeserializer.java b/spring-core/src/main/java/org/springframework/core/serializer/DefaultDeserializer.java index 07bf246b05..256342d1f9 100644 --- a/spring-core/src/main/java/org/springframework/core/serializer/DefaultDeserializer.java +++ b/spring-core/src/main/java/org/springframework/core/serializer/DefaultDeserializer.java @@ -67,8 +67,7 @@ public class DefaultDeserializer implements Deserializer { @Override @SuppressWarnings("resource") public Object deserialize(InputStream inputStream) throws IOException { - ObjectInputStream objectInputStream = new ConfigurableObjectInputStream(inputStream, this.classLoader); - try { + try (ConfigurableObjectInputStream objectInputStream = new ConfigurableObjectInputStream(inputStream, this.classLoader)){ return objectInputStream.readObject(); } catch (ClassNotFoundException ex) { diff --git a/spring-core/src/test/java/org/springframework/core/serializer/SerializationConverterTests.java b/spring-core/src/test/java/org/springframework/core/serializer/SerializationConverterTests.java index f57abf820e..446cef5090 100644 --- a/spring-core/src/test/java/org/springframework/core/serializer/SerializationConverterTests.java +++ b/spring-core/src/test/java/org/springframework/core/serializer/SerializationConverterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,17 +16,24 @@ package org.springframework.core.serializer; +import java.io.ByteArrayInputStream; import java.io.NotSerializableException; import java.io.Serializable; import org.junit.jupiter.api.Test; +import org.mockito.MockedConstruction; +import org.mockito.Mockito; +import org.springframework.core.ConfigurableObjectInputStream; import org.springframework.core.serializer.support.DeserializingConverter; import org.springframework.core.serializer.support.SerializationFailedException; import org.springframework.core.serializer.support.SerializingConverter; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.BDDMockito.given; + /** * @author Gary Russell @@ -43,27 +50,64 @@ class SerializationConverterTests { assertThat(fromBytes.convert(bytes)).isEqualTo("Testing"); } + @Test + void serializeAndDeserializeStringWithCustomSerializer() { + SerializingConverter toBytes = new SerializingConverter(new DefaultSerializer()); + byte[] bytes = toBytes.convert("Testing"); + DeserializingConverter fromBytes = new DeserializingConverter(); + assertThat(fromBytes.convert(bytes)).isEqualTo("Testing"); + } + @Test void nonSerializableObject() { SerializingConverter toBytes = new SerializingConverter(); - assertThatExceptionOfType(SerializationFailedException.class).isThrownBy(() -> - toBytes.convert(new Object())) - .withCauseInstanceOf(IllegalArgumentException.class); + assertThatExceptionOfType(SerializationFailedException.class) + .isThrownBy(() -> toBytes.convert(new Object())) + .withCauseInstanceOf(IllegalArgumentException.class); } @Test void nonSerializableField() { SerializingConverter toBytes = new SerializingConverter(); - assertThatExceptionOfType(SerializationFailedException.class).isThrownBy(() -> - toBytes.convert(new UnSerializable())) - .withCauseInstanceOf(NotSerializableException.class); + assertThatExceptionOfType(SerializationFailedException.class) + .isThrownBy(() -> toBytes.convert(new UnSerializable())) + .withCauseInstanceOf(NotSerializableException.class); } @Test void deserializationFailure() { DeserializingConverter fromBytes = new DeserializingConverter(); - assertThatExceptionOfType(SerializationFailedException.class).isThrownBy(() -> - fromBytes.convert("Junk".getBytes())); + assertThatExceptionOfType(SerializationFailedException.class) + .isThrownBy(() -> fromBytes.convert("Junk".getBytes())); + } + + @Test + void deserializationWithClassLoader() { + DeserializingConverter fromBytes = new DeserializingConverter(this.getClass().getClassLoader()); + SerializingConverter toBytes = new SerializingConverter(); + String expected = "SPRING FRAMEWORK"; + assertThat(fromBytes.convert(toBytes.convert(expected))).isEqualTo(expected); + } + + @Test + void deserializationWithDeserializer() { + DeserializingConverter fromBytes = new DeserializingConverter(new DefaultDeserializer()); + SerializingConverter toBytes = new SerializingConverter(); + String expected = "SPRING FRAMEWORK"; + assertThat(fromBytes.convert(toBytes.convert(expected))).isEqualTo(expected); + } + + @Test + void deserializationIOException() { + try (MockedConstruction mocked = Mockito.mockConstruction( + ConfigurableObjectInputStream.class, (mock, context) -> given(mock.readObject()) + .willThrow(new ClassNotFoundException()))) { + DefaultDeserializer defaultSerializer = new DefaultDeserializer(this.getClass().getClassLoader()); + assertThat(mocked).isNotNull(); + assertThatThrownBy(() -> defaultSerializer.deserialize( + new ByteArrayInputStream("test".getBytes()))) + .hasMessage("Failed to deserialize object type"); + } } diff --git a/spring-core/src/test/java/org/springframework/core/serializer/SerializerTests.java b/spring-core/src/test/java/org/springframework/core/serializer/SerializerTests.java new file mode 100644 index 0000000000..38233499fb --- /dev/null +++ b/spring-core/src/test/java/org/springframework/core/serializer/SerializerTests.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.core.serializer; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +import org.junit.jupiter.api.Test; + +import org.springframework.core.serializer.support.SerializationDelegate; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + + +class SerializerTests { + + private static final String SPRING_FRAMEWORK = "Spring Framework"; + + + @Test + void serializeToByteArray() throws IOException { + SpyStringSerializer serializer = new SpyStringSerializer(); + serializer.serializeToByteArray(SPRING_FRAMEWORK); + assertThat(serializer.expectedObject).isEqualTo(SPRING_FRAMEWORK); + assertThat(serializer.expectedOs).isNotNull(); + } + + @Test + void deserializeToByteArray() throws IOException { + SpyStringDeserializer deserializer = new SpyStringDeserializer(); + deserializer.deserializeFromByteArray(SPRING_FRAMEWORK.getBytes()); + assertThat(deserializer.expectedObject).isEqualTo(SPRING_FRAMEWORK); + } + + @Test + void serializationDelegate() throws IOException { + SerializationDelegate delegate = new SerializationDelegate(new DefaultSerializer(), new DefaultDeserializer()); + byte[] serializedObj = delegate.serializeToByteArray(SPRING_FRAMEWORK); + Object deserializedObj = delegate.deserialize(new ByteArrayInputStream(serializedObj)); + assertThat(deserializedObj).isEqualTo(SPRING_FRAMEWORK); + } + + @Test + void serializationDelegateWithClassLoader() throws IOException { + SerializationDelegate delegate = new SerializationDelegate(this.getClass().getClassLoader()); + byte[] serializedObj = delegate.serializeToByteArray(SPRING_FRAMEWORK); + Object deserializedObj = delegate.deserialize(new ByteArrayInputStream(serializedObj)); + assertThat(deserializedObj).isEqualTo(SPRING_FRAMEWORK); + } + + static class SpyStringSerializer implements Serializer { + T expectedObject; + OutputStream expectedOs; + + @Override + public void serialize(T object, OutputStream outputStream) { + expectedObject = object; + expectedOs = outputStream; + } + } + + static class SpyStringDeserializer implements Deserializer { + Object expectedObject; + + + @Override + public String deserialize(InputStream inputStream) { + expectedObject = SPRING_FRAMEWORK; + return SPRING_FRAMEWORK; + } + } +}