INT-1396 DefaultInboundRequestMapper now properly handles reading large byte payloads from request input streams

This commit is contained in:
Mark Fisher
2011-02-01 14:35:32 -05:00
parent 9875e3f134
commit 2437db435f
2 changed files with 88 additions and 6 deletions

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2002-2009 the original author or authors.
* Copyright 2002-2011 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
package org.springframework.integration.http;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
@@ -225,7 +226,8 @@ public class DefaultInboundRequestMapper implements InboundRequestMapper {
}
private byte[] createPayloadFromInputStream(HttpServletRequest request) throws Exception {
InputStream stream = request.getInputStream();
int bufferSize = 4096;
InputStream in = request.getInputStream();
int length = request.getContentLength();
if (length == -1) {
throw new ResponseStatusCodeException(HttpServletResponse.SC_LENGTH_REQUIRED);
@@ -234,9 +236,13 @@ public class DefaultInboundRequestMapper implements InboundRequestMapper {
logger.debug("received " + request.getMethod() + " request, "
+ "creating byte array payload with content lenth: " + length);
}
byte[] bytes = new byte[length];
stream.read(bytes, 0, length);
return bytes;
ByteArrayOutputStream out = new ByteArrayOutputStream(bufferSize);
byte[] buffer = new byte[bufferSize];
int bytesRead = -1;
while ((bytesRead = in.read(buffer)) != -1) {
out.write(buffer, 0, bytesRead);
}
return out.toByteArray();
}
private void populateHeaders(HttpServletRequest request, MessageBuilder<?> builder) {

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2002-2009 the original author or authors.
* Copyright 2002-2011 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,11 +13,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.integration.http;
import static org.easymock.EasyMock.createMock;
import static org.easymock.EasyMock.expect;
import static org.easymock.EasyMock.replay;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.atomic.AtomicInteger;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import org.junit.Test;
import org.springframework.integration.core.Message;
import org.springframework.mock.web.MockHttpServletRequest;
@@ -81,4 +93,68 @@ public class DefaultInboundRequestMapperTests {
assertThat(message.getPayload(), is(content));
}
@Test
public void largeStreamTest() throws Exception {
final int size = 100000;
final byte[] content = new byte[size];
for (int i = 0; i < size; i++) {
content[i] = 7;
}
HttpServletRequest request = createMock(HttpServletRequest.class);
final InputStream inputStream = new InputStream() {
private final AtomicInteger next = new AtomicInteger(0);
public int read() throws IOException {
if (next.get() >= size) {
return -1;
}
int index = next.getAndIncrement();
return content[index];
}
};
expect(request.getInputStream()).andReturn(new ServletInputStream() {
public int read(byte[] b) throws IOException {
int length = b.length;
if (length > 99990) {
length = 99990;
}
byte[] temp = new byte[length];
int numRead = inputStream.read(temp);
for (int i = 0; i < numRead; i++) {
b[i] = temp[i];
}
return numRead;
}
public int read(byte[] b, int off, int length) throws IOException {
if (length > 99990) {
length = 99990;
}
byte[] temp = new byte[length];
int numRead = inputStream.read(temp, 0, length);
for (int i = 0; i < numRead; i++) {
b[off + i] = temp[i];
}
return numRead;
}
public int read() throws IOException {
return inputStream.read();
}
});
expect(request.getContentType()).andReturn(null);
expect(request.getMethod()).andReturn("POST").anyTimes();
expect(request.getContentLength()).andReturn(size);
expect(request.getHeaderNames()).andReturn(null);
expect(request.getRequestURL()).andReturn(new StringBuffer("http://test"));
expect(request.getUserPrincipal()).andReturn(null);
replay(request);
Message<?> message = mapper.toMessage(request);
byte[] payload = (byte[]) message.getPayload();
int byteValueCounter = 0;
for (byte b : payload) {
if (b == 7) {
byteValueCounter++;
}
}
assertEquals(size, byteValueCounter);
}
}