From 2437db435fe1f8aefb4f8e7f0a58d74e90386d90 Mon Sep 17 00:00:00 2001 From: Mark Fisher Date: Tue, 1 Feb 2011 14:35:32 -0500 Subject: [PATCH] INT-1396 DefaultInboundRequestMapper now properly handles reading large byte payloads from request input streams --- .../http/DefaultInboundRequestMapper.java | 16 ++-- .../DefaultInboundRequestMapperTests.java | 78 ++++++++++++++++++- 2 files changed, 88 insertions(+), 6 deletions(-) diff --git a/org.springframework.integration.http/src/main/java/org/springframework/integration/http/DefaultInboundRequestMapper.java b/org.springframework.integration.http/src/main/java/org/springframework/integration/http/DefaultInboundRequestMapper.java index 4c7578f5ed..cb320e0496 100644 --- a/org.springframework.integration.http/src/main/java/org/springframework/integration/http/DefaultInboundRequestMapper.java +++ b/org.springframework.integration.http/src/main/java/org/springframework/integration/http/DefaultInboundRequestMapper.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2009 the original author or authors. + * Copyright 2002-2011 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.integration.http; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.ObjectInputStream; @@ -225,7 +226,8 @@ public class DefaultInboundRequestMapper implements InboundRequestMapper { } private byte[] createPayloadFromInputStream(HttpServletRequest request) throws Exception { - InputStream stream = request.getInputStream(); + int bufferSize = 4096; + InputStream in = request.getInputStream(); int length = request.getContentLength(); if (length == -1) { throw new ResponseStatusCodeException(HttpServletResponse.SC_LENGTH_REQUIRED); @@ -234,9 +236,13 @@ public class DefaultInboundRequestMapper implements InboundRequestMapper { logger.debug("received " + request.getMethod() + " request, " + "creating byte array payload with content lenth: " + length); } - byte[] bytes = new byte[length]; - stream.read(bytes, 0, length); - return bytes; + ByteArrayOutputStream out = new ByteArrayOutputStream(bufferSize); + byte[] buffer = new byte[bufferSize]; + int bytesRead = -1; + while ((bytesRead = in.read(buffer)) != -1) { + out.write(buffer, 0, bytesRead); + } + return out.toByteArray(); } private void populateHeaders(HttpServletRequest request, MessageBuilder builder) { diff --git a/org.springframework.integration.http/src/test/java/org/springframework/integration/http/DefaultInboundRequestMapperTests.java b/org.springframework.integration.http/src/test/java/org/springframework/integration/http/DefaultInboundRequestMapperTests.java index 3fe184c8ca..d8b2eef822 100644 --- a/org.springframework.integration.http/src/test/java/org/springframework/integration/http/DefaultInboundRequestMapperTests.java +++ b/org.springframework.integration.http/src/test/java/org/springframework/integration/http/DefaultInboundRequestMapperTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2009 the original author or authors. + * Copyright 2002-2011 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,11 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.integration.http; +import static org.easymock.EasyMock.createMock; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.replay; import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; +import java.io.IOException; +import java.io.InputStream; +import java.util.concurrent.atomic.AtomicInteger; + +import javax.servlet.ServletInputStream; +import javax.servlet.http.HttpServletRequest; + import org.junit.Test; import org.springframework.integration.core.Message; import org.springframework.mock.web.MockHttpServletRequest; @@ -81,4 +93,68 @@ public class DefaultInboundRequestMapperTests { assertThat(message.getPayload(), is(content)); } + @Test + public void largeStreamTest() throws Exception { + final int size = 100000; + final byte[] content = new byte[size]; + for (int i = 0; i < size; i++) { + content[i] = 7; + } + HttpServletRequest request = createMock(HttpServletRequest.class); + final InputStream inputStream = new InputStream() { + private final AtomicInteger next = new AtomicInteger(0); + public int read() throws IOException { + if (next.get() >= size) { + return -1; + } + int index = next.getAndIncrement(); + return content[index]; + } + }; + expect(request.getInputStream()).andReturn(new ServletInputStream() { + public int read(byte[] b) throws IOException { + int length = b.length; + if (length > 99990) { + length = 99990; + } + byte[] temp = new byte[length]; + int numRead = inputStream.read(temp); + for (int i = 0; i < numRead; i++) { + b[i] = temp[i]; + } + return numRead; + } + public int read(byte[] b, int off, int length) throws IOException { + if (length > 99990) { + length = 99990; + } + byte[] temp = new byte[length]; + int numRead = inputStream.read(temp, 0, length); + for (int i = 0; i < numRead; i++) { + b[off + i] = temp[i]; + } + return numRead; + } + public int read() throws IOException { + return inputStream.read(); + } + }); + expect(request.getContentType()).andReturn(null); + expect(request.getMethod()).andReturn("POST").anyTimes(); + expect(request.getContentLength()).andReturn(size); + expect(request.getHeaderNames()).andReturn(null); + expect(request.getRequestURL()).andReturn(new StringBuffer("http://test")); + expect(request.getUserPrincipal()).andReturn(null); + replay(request); + Message message = mapper.toMessage(request); + byte[] payload = (byte[]) message.getPayload(); + int byteValueCounter = 0; + for (byte b : payload) { + if (b == 7) { + byteValueCounter++; + } + } + assertEquals(size, byteValueCounter); + } + }