diff --git a/spring-core/src/main/java/org/springframework/util/DigestUtils.java b/spring-core/src/main/java/org/springframework/util/DigestUtils.java index a4ca743bf9..29dcb54105 100644 --- a/spring-core/src/main/java/org/springframework/util/DigestUtils.java +++ b/spring-core/src/main/java/org/springframework/util/DigestUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2012 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.springframework.util; +import java.io.IOException; +import java.io.InputStream; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; @@ -45,6 +47,16 @@ public abstract class DigestUtils { return digest(MD5_ALGORITHM_NAME, bytes); } + /** + * Calculate the MD5 digest of the given InputStream. + * @param inputStream the inputStream to calculate the digest over + * @return the digest + * @since 4.2 + */ + public static byte[] md5Digest(InputStream inputStream) throws IOException{ + return digest(MD5_ALGORITHM_NAME, inputStream); + } + /** * Return a hexadecimal string representation of the MD5 digest of the given * bytes. @@ -55,6 +67,17 @@ public abstract class DigestUtils { return digestAsHexString(MD5_ALGORITHM_NAME, bytes); } + /** + * Return a hexadecimal string representation of the MD5 digest of the given + * inputStream. + * @param inputStream the inputStream to calculate the digest over + * @return a hexadecimal digest string + * @since 4.2 + */ + public static String md5DigestAsHex(InputStream inputStream) throws IOException{ + return digestAsHexString(MD5_ALGORITHM_NAME, inputStream); + } + /** * Append a hexadecimal string representation of the MD5 digest of the given * bytes to the given {@link StringBuilder}. @@ -66,6 +89,18 @@ public abstract class DigestUtils { return appendDigestAsHex(MD5_ALGORITHM_NAME, bytes, builder); } + /** + * Append a hexadecimal string representation of the MD5 digest of the given + * inputStream to the given {@link StringBuilder}. + * @param inputStream the inputStream to calculate the digest over + * @param builder the string builder to append the digest to + * @return the given string builder + * @since 4.2 + */ + public static StringBuilder appendMd5DigestAsHex(InputStream inputStream, StringBuilder builder) throws IOException{ + return appendDigestAsHex(MD5_ALGORITHM_NAME, inputStream, builder); + } + /** * Creates a new {@link MessageDigest} with the given algorithm. Necessary * because {@code MessageDigest} is not thread-safe. @@ -83,21 +118,46 @@ public abstract class DigestUtils { return getDigest(algorithm).digest(bytes); } + private static byte[] digest(String algorithm, InputStream inputStream) throws IOException{ + MessageDigest messageDigest = getDigest(algorithm); + if(inputStream instanceof UpdateMessageDigestInputStream){ + ((UpdateMessageDigestInputStream) inputStream).updateMessageDigest(messageDigest); + return messageDigest.digest(); + }else{ + return messageDigest.digest(StreamUtils.copyToByteArray(inputStream)); + } + } + private static String digestAsHexString(String algorithm, byte[] bytes) { char[] hexDigest = digestAsHexChars(algorithm, bytes); return new String(hexDigest); } + private static String digestAsHexString(String algorithm, InputStream inputStream) throws IOException{ + char[] hexDigest = digestAsHexChars(algorithm, inputStream); + return new String(hexDigest); + } + private static StringBuilder appendDigestAsHex(String algorithm, byte[] bytes, StringBuilder builder) { char[] hexDigest = digestAsHexChars(algorithm, bytes); return builder.append(hexDigest); } + private static StringBuilder appendDigestAsHex(String algorithm, InputStream inputStream, StringBuilder builder) throws IOException{ + char[] hexDigest = digestAsHexChars(algorithm, inputStream); + return builder.append(hexDigest); + } + private static char[] digestAsHexChars(String algorithm, byte[] bytes) { byte[] digest = digest(algorithm, bytes); return encodeHex(digest); } + private static char[] digestAsHexChars(String algorithm, InputStream inputStream) throws IOException{ + byte[] digest = digest(algorithm, inputStream); + return encodeHex(digest); + } + private static char[] encodeHex(byte[] bytes) { char chars[] = new char[32]; for (int i = 0; i < chars.length; i = i + 2) { diff --git a/spring-core/src/main/java/org/springframework/util/FastByteArrayOutputStream.java b/spring-core/src/main/java/org/springframework/util/FastByteArrayOutputStream.java new file mode 100644 index 0000000000..6c6eb46e3a --- /dev/null +++ b/spring-core/src/main/java/org/springframework/util/FastByteArrayOutputStream.java @@ -0,0 +1,544 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.util; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.security.MessageDigest; +import java.util.Iterator; +import java.util.LinkedList; + +/** + * A speedy alternative to {@link java.io.ByteArrayOutputStream}. + *

Unlike {@link java.io.ByteArrayOutputStream}, this implementation is backed by a + * {@link java.util.LinkedList} of byte[] instead of 1 constantly resizing byte[]. + * It does not copy buffers when it's expanded.

+ * + *

The initial buffer is only created when the stream is first written. + * There's also no copying of the internal buffer if its contents is extracted with the + * {@link #writeTo(OutputStream)} method. + * Instances of this class are NOT THREAD SAFE.

+ * + * @author Craig Andrews + * @since 4.2 + */ +public final class FastByteArrayOutputStream extends OutputStream { + + private static final int DEFAULT_BLOCK_SIZE = 256; + + // the buffers used to store the content bytes + private final LinkedList buffers = new LinkedList(); + + // is the stream closed? + private boolean closed = false; + + // the size, in bytes, to use when allocating the next next byte[] + private int nextBlockSize; + + // the index in the byte[] found at buffers.getLast() to be written next + private int index = 0; + + // number of bytes in previous buffers + // the number of bytes in the current buffer is in index + private int alreadyBufferedSize = 0; + + // the size, in bytes, to use when allocating the first byte[] + private final int initialBlockSize; + + /** + * Create a new FastByteArrayOutputStream + * with the default initial capacity of {@value #DEFAULT_BLOCK_SIZE} bytes. + */ + public FastByteArrayOutputStream() { + this(DEFAULT_BLOCK_SIZE); + } + + /** + * Create a new FastByteArrayOutputStream + * with the specified initial capacity. + * @param initialBlockSize the initial buffer size in bytes + */ + public FastByteArrayOutputStream(int initialBlockSize) { + Assert.isTrue(initialBlockSize > 0, "Initial block size must be greater than 0"); + this.initialBlockSize = initialBlockSize; + nextBlockSize = initialBlockSize; + } + + @Override + public void write(int datum) throws IOException { + if (closed) { + throw new IOException("Stream closed"); + } + else { + if (buffers.peekLast() == null || buffers.getLast().length == index) { + addBuffer(1); + } + // store the byte + buffers.getLast()[index++] = (byte) datum; + } + } + + @Override + public void write(byte[] data, int offset, int length) throws IOException { + if (data == null) { + throw new NullPointerException(); + } + else if ((offset < 0) || ((offset + length) > data.length) || (length < 0)) { + throw new IndexOutOfBoundsException(); + } + else if (closed) { + throw new IOException("Stream closed"); + } + else { + if (buffers.peekLast() == null || buffers.getLast().length == index) { + addBuffer(length); + } + if ((index + length) > buffers.getLast().length) { + do { + if (index == buffers.getLast().length) { + addBuffer(length); + } + int copyLength = buffers.getLast().length - index; + if (length < copyLength) { + copyLength = length; + } + System.arraycopy(data, offset, buffers.getLast(), index, copyLength); + offset += copyLength; + index += copyLength; + length -= copyLength; + } while (length > 0); + } + else { + // Copy in the subarray + System.arraycopy(data, offset, buffers.getLast(), index, length); + index += length; + } + } + } + + @Override + public void close() { + closed = true; + } + + /** + * Returns the number of bytes stored in this FastByteArrayOutputStream + */ + public int size() { + return alreadyBufferedSize + index; + } + + /** + * Convert the stream's data to a byte array and return the byte array. + * + *

Also replaces the internal structures with the byte array to conserve memory: + * if the byte array is being made anyways, mind as well as use it. + * This approach also means that if this method is called twice without any writes in between, + * the second call is a no-op. + * This method is "unsafe" as it returns the internal buffer - callers should not modify the returned buffer.

+ * + * @return the current contents of this output stream, as a byte array. + * @see #size() + * @see #toByteArray() + */ + public byte[] toByteArrayUnsafe() { + int totalSize = size(); + if (totalSize == 0) { + return new byte[0]; + } + resize(totalSize); + return buffers.getFirst(); + } + + /** + * Creates a newly allocated byte array. + * + *

Its size is the current + * size of this output stream and the valid contents of the buffer + * have been copied into it.

+ * + * @return the current contents of this output stream, as a byte array. + * @see #size() + * @see #toByteArrayUnsafe() + */ + public byte[] toByteArray() { + byte[] bytesUnsafe = toByteArrayUnsafe(); + byte[] ret = new byte[bytesUnsafe.length]; + System.arraycopy(bytesUnsafe, 0, ret, 0, bytesUnsafe.length); + return ret; + } + + /** + * Resets the contents of this FastByteArrayOutputStreamInputStream + *

All currently accumulated output in the output stream is discarded. + * The output stream can be used again.

+ */ + public void reset() { + buffers.clear(); + nextBlockSize = initialBlockSize; + closed = false; + index = 0; + alreadyBufferedSize = 0; + } + + /** + * Get an {@link java.io.InputStream} to retrieve the data in this OutputStream + * + *

Note that if any methods are called on the OutputStream + * (including, but not limited to, any of the write methods, {@link #reset()}, + * {@link #toByteArray()}, and {@link #toByteArrayUnsafe()}) then the {@link java.io.InputStream}'s + * behavior is undefined.

+ * + * @return {@link java.io.InputStream} of the contents of this FastByteArrayOutputStreamInputStream + */ + public InputStream getInputStream() { + return new FastByteArrayOutputStreamInputStream(this); + } + + /** + * Write the buffers content to the given OutputStream + * + * @param out the OutputStream to write to + */ + public void writeTo(OutputStream out) throws IOException { + Iterator iter = buffers.iterator(); + + while (iter.hasNext()) { + byte[] bytes = iter.next(); + if (iter.hasNext()) { + out.write(bytes, 0, bytes.length); + } + else { + out.write(bytes, 0, index); + } + } + } + + /** + * Resize the internal buffer size to a specified capacity. + * + * @param targetCapacity the desired size of the buffer + * @throws IllegalArgumentException if the given capacity is smaller than + * the actual size of the content stored in the buffer already + * @see FastByteArrayOutputStream#size() + */ + public void resize(int targetCapacity) { + Assert.isTrue(targetCapacity >= size(), "New capacity must not be smaller than current size"); + if (buffers.peekFirst() == null) { + nextBlockSize = targetCapacity - size(); + } + else if (size() == targetCapacity && buffers.getFirst().length == targetCapacity) { + // do nothing - already at the targetCapacity + } + else { + int totalSize = size(); + byte[] data = new byte[targetCapacity]; + int pos = 0; + Iterator iter = buffers.iterator(); + while (iter.hasNext()) { + byte[] bytes = iter.next(); + if (iter.hasNext()) { + System.arraycopy(bytes, 0, data, pos, bytes.length); + pos += bytes.length; + } + else { + System.arraycopy(bytes, 0, data, pos, index); + } + } + buffers.clear(); + buffers.add(data); + index = totalSize; + alreadyBufferedSize = 0; + } + } + + /** + * Create a new buffer and store it in the LinkedList + * + *

Adds a new buffer that can store at least {@code minCapacity} bytes

+ */ + private void addBuffer(int minCapacity) { + if (buffers.peekLast() != null) { + alreadyBufferedSize += index; + index = 0; + } + if (nextBlockSize < minCapacity) { + nextBlockSize = nextPowerOf2(minCapacity); + } + buffers.add(new byte[nextBlockSize]); + nextBlockSize *= 2; // block size doubles each time + } + + /** + * Get the next power of 2 of a number (ex, the next power of 2 of 119 is 128) + */ + private static final int nextPowerOf2(int val) { + val--; + val = (val >> 1) | val; + val = (val >> 2) | val; + val = (val >> 4) | val; + val = (val >> 8) | val; + val = (val >> 16) | val; + val++; + return val; + } + + /** + * Converts the buffer's contents into a string decoding bytes using the + * platform's default character set. The length of the new String + * is a function of the character set, and hence may not be equal to the + * size of the buffer. + * + *

This method always replaces malformed-input and unmappable-character + * sequences with the default replacement string for the platform's + * default character set. The {@linkplain java.nio.charset.CharsetDecoder} + * class should be used when more control over the decoding process is + * required.

+ * + * @return String decoded from the buffer's contents. + */ + @Override + public String toString() { + return new String(toByteArrayUnsafe()); + } + + + /** + * An implementation of {@link java.io.InputStream} that reads from FastByteArrayOutputStream + * Instances of this class are NOT THREAD SAFE. + */ + private static final class FastByteArrayOutputStreamInputStream extends UpdateMessageDigestInputStream { + int totalBytesRead = 0; + + int nextIndexInCurrentBuffer = 0; + + final Iterator buffersIterator; + + byte[] currentBuffer; + + int currentBufferLength; + + final FastByteArrayOutputStream fastByteArrayOutputStream; + + /** + * Create a new FastByteArrayOutputStreamInputStream backed + * by the given FastByteArrayOutputStream + */ + public FastByteArrayOutputStreamInputStream(FastByteArrayOutputStream fastByteArrayOutputStream) { + this.fastByteArrayOutputStream = fastByteArrayOutputStream; + buffersIterator = fastByteArrayOutputStream.buffers.iterator(); + if (buffersIterator.hasNext()) { + currentBuffer = buffersIterator.next(); + if (currentBuffer == fastByteArrayOutputStream.buffers.getLast()) { + currentBufferLength = fastByteArrayOutputStream.index; + } + else { + currentBufferLength = currentBuffer.length; + } + } + else { + currentBuffer = null; + } + } + + @Override + public int read() { + if (currentBuffer == null) { + // this stream doesn't have any data in it + return -1; + } + else { + if (nextIndexInCurrentBuffer < currentBufferLength) { + totalBytesRead++; + return currentBuffer[nextIndexInCurrentBuffer++]; + } + else { + if (buffersIterator.hasNext()) { + currentBuffer = buffersIterator.next(); + if (currentBuffer == fastByteArrayOutputStream.buffers.getLast()) { + currentBufferLength = fastByteArrayOutputStream.index; + } + else { + currentBufferLength = currentBuffer.length; + } + nextIndexInCurrentBuffer = 0; + } + else { + currentBuffer = null; + } + return read(); + } + } + } + + @Override + public int read(byte[] b) { + return read(b, 0, b.length); + } + + @Override + public int read(byte[] b, int off, int len) { + if (b == null) { + throw new NullPointerException(); + } + else if (off < 0 || len < 0 || len > b.length - off) { + throw new IndexOutOfBoundsException(); + } + else if (len == 0) { + return 0; + } + else if (len < 0) { + throw new IllegalArgumentException("len must be 0 or greater: " + len); + } + else if (off < 0) { + throw new IllegalArgumentException("off must be 0 or greater: " + off); + } + else { + if (currentBuffer == null) { + // this stream doesn't have any data in it + return 0; + } + else { + if (nextIndexInCurrentBuffer < currentBufferLength) { + int bytesToCopy = Math.min(len, currentBufferLength - nextIndexInCurrentBuffer); + System.arraycopy(currentBuffer, nextIndexInCurrentBuffer, b, off, bytesToCopy); + totalBytesRead += bytesToCopy; + nextIndexInCurrentBuffer += bytesToCopy; + return bytesToCopy + read(b, off + bytesToCopy, len - bytesToCopy); + } + else { + if (buffersIterator.hasNext()) { + currentBuffer = buffersIterator.next(); + if (currentBuffer == fastByteArrayOutputStream.buffers.getLast()) { + currentBufferLength = fastByteArrayOutputStream.index; + } + else { + currentBufferLength = currentBuffer.length; + } + nextIndexInCurrentBuffer = 0; + } + else { + currentBuffer = null; + } + return read(b, off, len); + } + } + } + } + + @Override + public long skip(long n) throws IOException { + if (n > Integer.MAX_VALUE) { + throw new IllegalArgumentException("n exceeds maximum (" + + Integer.MAX_VALUE + "): " + n); + } + else if (n == 0) { + return 0; + } + else if (n < 0) { + throw new IllegalArgumentException("n must be 0 or greater: " + n); + } + int len = (int) n; + if (currentBuffer == null) { + // this stream doesn't have any data in it + return 0; + } + else { + if (nextIndexInCurrentBuffer < currentBufferLength) { + int bytesToSkip = Math.min(len, currentBufferLength - nextIndexInCurrentBuffer); + totalBytesRead += bytesToSkip; + nextIndexInCurrentBuffer += bytesToSkip; + return bytesToSkip + skip(len - bytesToSkip); + } + else { + if (buffersIterator.hasNext()) { + currentBuffer = buffersIterator.next(); + if (currentBuffer == fastByteArrayOutputStream.buffers.getLast()) { + currentBufferLength = fastByteArrayOutputStream.index; + } + else { + currentBufferLength = currentBuffer.length; + } + nextIndexInCurrentBuffer = 0; + } + else { + currentBuffer = null; + } + return skip(len); + } + } + } + + @Override + public int available() { + return fastByteArrayOutputStream.size() - totalBytesRead; + } + + /** + * Update the message digest with the remaining bytes in this stream. + * + * @param messageDigest The message digest to update + */ + public void updateMessageDigest(MessageDigest messageDigest) { + updateMessageDigest(messageDigest, available()); + } + + /** + * Update the message digest with the next len bytes in this stream. + * Avoids creating new byte arrays and use internal buffers for performance. + * @param messageDigest The message digest to update + * @param len how many bytes to read from this stream and use to update the message digest + */ + public void updateMessageDigest(MessageDigest messageDigest, int len) { + if (currentBuffer == null) { + // this stream doesn't have any data in it + return; + } + else if (len == 0) { + return; + } + else if (len < 0) { + throw new IllegalArgumentException("len must be 0 or greater: " + len); + } + else { + if (nextIndexInCurrentBuffer < currentBufferLength) { + int bytesToCopy = Math.min(len, currentBufferLength - nextIndexInCurrentBuffer); + messageDigest.update(currentBuffer, nextIndexInCurrentBuffer, bytesToCopy); + nextIndexInCurrentBuffer += bytesToCopy; + updateMessageDigest(messageDigest, len - bytesToCopy); + } + else { + if (buffersIterator.hasNext()) { + currentBuffer = buffersIterator.next(); + if (currentBuffer == fastByteArrayOutputStream.buffers.getLast()) { + currentBufferLength = fastByteArrayOutputStream.index; + } + else { + currentBufferLength = currentBuffer.length; + } + nextIndexInCurrentBuffer = 0; + } + else { + currentBuffer = null; + } + updateMessageDigest(messageDigest, len); + } + } + } + } +} diff --git a/spring-core/src/main/java/org/springframework/util/UpdateMessageDigestInputStream.java b/spring-core/src/main/java/org/springframework/util/UpdateMessageDigestInputStream.java new file mode 100644 index 0000000000..8ae173627e --- /dev/null +++ b/spring-core/src/main/java/org/springframework/util/UpdateMessageDigestInputStream.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.util; + +import java.io.IOException; +import java.io.InputStream; +import java.security.MessageDigest; + +/** + * Extension of {@link java.io.InputStream} that allows for optimized + * implementations of message digesting. + * + * @author Craig Andrews + * @since 4.2 + */ +public abstract class UpdateMessageDigestInputStream extends InputStream { + + /** + * Update the message digest with the rest of the bytes in this stream + * + *

Using this method is more optimized since it avoids creating new byte arrays for each call.

+ * + * @param messageDigest The message digest to update + * @throws IOException + */ + public void updateMessageDigest(MessageDigest messageDigest) throws IOException{ + int data; + while((data = read()) != -1){ + messageDigest.update((byte)data); + } + } + + /** + * Update the message digest with the next len bytes in this stream + * + *

Using this method is more optimized since it avoids creating new byte arrays for each call.

+ * + * @param messageDigest The message digest to update + * @param len how many bytes to read from this stream and use to update the message digest + * @throws IOException + */ + public void updateMessageDigest(MessageDigest messageDigest, int len) throws IOException{ + int data; + int bytesRead = 0; + while(bytesRead < len && (data = read()) != -1){ + messageDigest.update((byte)data); + bytesRead++; + } + } +} diff --git a/spring-core/src/test/java/org/springframework/util/FastByteArrayOutputStreamTests.java b/spring-core/src/test/java/org/springframework/util/FastByteArrayOutputStreamTests.java new file mode 100644 index 0000000000..9eedeca8ee --- /dev/null +++ b/spring-core/src/test/java/org/springframework/util/FastByteArrayOutputStreamTests.java @@ -0,0 +1,196 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.util; + +import static org.junit.Assert.*; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; + +import org.junit.Before; +import org.junit.Test; + +/** + * Test suite for {@link FastByteArrayOutputStream} + * @author Craig Andrews + */ +public class FastByteArrayOutputStreamTests { + + private static final int INITIAL_CAPACITY = 256; + + private FastByteArrayOutputStream os; + + private byte[] helloBytes; + + @Before + public void setUp() throws Exception { + this.os = new FastByteArrayOutputStream(INITIAL_CAPACITY); + this.helloBytes = "Hello World".getBytes("UTF-8"); + } + + @Test + public void size() throws Exception { + this.os.write(helloBytes); + assertEquals(this.os.size(), helloBytes.length); + } + + @Test + public void resize() throws Exception { + this.os.write(helloBytes); + int sizeBefore = this.os.size(); + this.os.resize(64); + assertByteArrayEqualsString(this.os); + assertEquals(sizeBefore, this.os.size()); + } + + @Test + public void autoGrow() throws IOException { + this.os.resize(1); + for (int i = 0; i < 10; i++) { + this.os.write(1); + } + assertEquals(10, this.os.size()); + assertArrayEquals(this.os.toByteArray(), new byte[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + } + + @Test + public void write() throws Exception { + this.os.write(helloBytes); + assertByteArrayEqualsString(this.os); + } + + @Test + public void reset() throws Exception { + this.os.write(helloBytes); + assertByteArrayEqualsString(this.os); + this.os.reset(); + assertEquals(0, this.os.size()); + this.os.write(helloBytes); + assertByteArrayEqualsString(this.os); + } + + @Test(expected = IOException.class) + public void close() throws Exception { + this.os.close(); + this.os.write(helloBytes); + } + + @Test + public void toByteArrayUnsafe() throws Exception { + this.os.write(helloBytes); + assertByteArrayEqualsString(this.os); + assertSame(this.os.toByteArrayUnsafe(), this.os.toByteArrayUnsafe()); + assertArrayEquals(this.os.toByteArray(), helloBytes); + } + + @Test + public void writeTo() throws Exception { + this.os.write(helloBytes); + assertByteArrayEqualsString(this.os); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + this.os.writeTo(baos); + assertArrayEquals(baos.toByteArray(), helloBytes); + } + + @Test(expected = IllegalArgumentException.class) + public void failResize() throws Exception { + this.os.write(helloBytes); + this.os.resize(5); + } + + @Test + public void getInputStream() throws Exception { + this.os.write(helloBytes); + assertNotNull(this.os.getInputStream()); + } + + @Test + public void getInputStreamAvailable() throws Exception { + this.os.write(helloBytes); + assertEquals(this.os.getInputStream().available(), helloBytes.length); + } + + @Test + public void getInputStreamRead() throws Exception { + this.os.write(helloBytes); + InputStream inputStream = this.os.getInputStream(); + assertEquals(inputStream.read(), helloBytes[0]); + assertEquals(inputStream.read(), helloBytes[1]); + assertEquals(inputStream.read(), helloBytes[2]); + assertEquals(inputStream.read(), helloBytes[3]); + } + + @Test + public void getInputStreamReadAll() throws Exception { + this.os.write(helloBytes); + InputStream inputStream = this.os.getInputStream(); + byte[] actual = new byte[inputStream.available()]; + int bytesRead = inputStream.read(actual); + assertEquals(bytesRead, helloBytes.length); + assertArrayEquals(actual, helloBytes); + assertEquals(0, inputStream.available()); + } + + @Test + public void getInputStreamSkip() throws Exception { + this.os.write(helloBytes); + InputStream inputStream = this.os.getInputStream(); + assertEquals(inputStream.read(), helloBytes[0]); + assertEquals(inputStream.skip(1), 1); + assertEquals(inputStream.read(), helloBytes[2]); + assertEquals(helloBytes.length - 3, inputStream.available()); + } + + @Test + public void getInputStreamSkipAll() throws Exception { + this.os.write(helloBytes); + InputStream inputStream = this.os.getInputStream(); + assertEquals(inputStream.skip(1000), helloBytes.length); + assertEquals(0, inputStream.available()); + } + + @Test + public void updateMessageDigest() throws Exception { + StringBuilder builder = new StringBuilder("\"0"); + this.os.write(helloBytes); + InputStream inputStream = this.os.getInputStream(); + DigestUtils.appendMd5DigestAsHex(inputStream, builder); + builder.append("\""); + String actual = builder.toString(); + assertEquals("\"0b10a8db164e0754105b7a99be72e3fe5\"", actual); + } + + @Test + public void updateMessageDigestManyBuffers() throws Exception { + StringBuilder builder = new StringBuilder("\"0"); + // filling at least one 256 buffer + for( int i=0; i < 30; i++) { + this.os.write(helloBytes); + } + InputStream inputStream = this.os.getInputStream(); + DigestUtils.appendMd5DigestAsHex(inputStream, builder); + builder.append("\""); + String actual = builder.toString(); + assertEquals("\"06225ca1e4533354c516e74512065331d\"", actual); + } + + private void assertByteArrayEqualsString(FastByteArrayOutputStream actual) { + assertArrayEquals(helloBytes, actual.toByteArray()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java index 8e8fa9d028..31e66eddf9 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.web.filter; import java.io.IOException; +import java.io.InputStream; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -26,7 +27,6 @@ import org.springframework.http.HttpMethod; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.DigestUtils; -import org.springframework.util.StreamUtils; import org.springframework.web.util.ContentCachingResponseWrapper; import org.springframework.web.util.WebUtils; @@ -93,15 +93,12 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { HttpServletResponse rawResponse = (HttpServletResponse) responseWrapper.getResponse(); int statusCode = responseWrapper.getStatusCode(); - byte[] body = responseWrapper.getContentAsByteArray(); if (rawResponse.isCommitted()) { - if (body.length > 0) { - StreamUtils.copy(body, rawResponse.getOutputStream()); - } + responseWrapper.copyBodyToResponse(); } - else if (isEligibleForEtag(request, responseWrapper, statusCode, body)) { - String responseETag = generateETagHeaderValue(body); + else if (isEligibleForEtag(request, responseWrapper, statusCode, responseWrapper.getContentInputStream())) { + String responseETag = generateETagHeaderValue(responseWrapper.getContentInputStream()); rawResponse.setHeader(HEADER_ETAG, responseETag); String requestETag = request.getHeader(HEADER_IF_NONE_MATCH); if (responseETag.equals(requestETag)) { @@ -115,20 +112,14 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { logger.trace("ETag [" + responseETag + "] not equal to If-None-Match [" + requestETag + "], sending normal response"); } - if (body.length > 0) { - rawResponse.setContentLength(body.length); - StreamUtils.copy(body, rawResponse.getOutputStream()); - } + responseWrapper.copyBodyToResponse(); } } else { if (logger.isTraceEnabled()) { logger.trace("Response with status code [" + statusCode + "] not eligible for ETag"); } - if (body.length > 0) { - rawResponse.setContentLength(body.length); - StreamUtils.copy(body, rawResponse.getOutputStream()); - } + responseWrapper.copyBodyToResponse(); } } @@ -143,11 +134,11 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { * @param request the HTTP request * @param response the HTTP response * @param responseStatusCode the HTTP response status code - * @param responseBody the response body + * @param inputStream the response body * @return {@code true} if eligible for ETag generation; {@code false} otherwise */ protected boolean isEligibleForEtag(HttpServletRequest request, HttpServletResponse response, - int responseStatusCode, byte[] responseBody) { + int responseStatusCode, InputStream inputStream) { if (responseStatusCode >= 200 && responseStatusCode < 300 && HttpMethod.GET.name().equals(request.getMethod())) { @@ -162,13 +153,18 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { /** * Generate the ETag header value from the given response body byte array. *

The default implementation generates an MD5 hash. - * @param bytes the response body as byte array + * @param inputStream the response body as an InputStream * @return the ETag header value * @see org.springframework.util.DigestUtils */ - protected String generateETagHeaderValue(byte[] bytes) { + protected String generateETagHeaderValue(InputStream inputStream) { StringBuilder builder = new StringBuilder("\"0"); - DigestUtils.appendMd5DigestAsHex(bytes, builder); + try { + DigestUtils.appendMd5DigestAsHex(inputStream, builder); + } + catch (IOException e) { + throw new RuntimeException(e); + } builder.append('"'); return builder.toString(); } diff --git a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java index 3eee67e72c..f28d5b2b52 100644 --- a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java +++ b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.web.util; import java.io.IOException; +import java.io.InputStream; import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.io.UnsupportedEncodingException; @@ -24,8 +25,7 @@ import javax.servlet.ServletOutputStream; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; -import org.springframework.util.ResizableByteArrayOutputStream; -import org.springframework.util.StreamUtils; +import org.springframework.util.FastByteArrayOutputStream; /** * {@link javax.servlet.http-HttpServletResponse} wrapper that caches all content written to @@ -39,7 +39,7 @@ import org.springframework.util.StreamUtils; */ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { - private final ResizableByteArrayOutputStream content = new ResizableByteArrayOutputStream(1024); + private final FastByteArrayOutputStream content = new FastByteArrayOutputStream(1024); private final ServletOutputStream outputStream = new ResponseServletOutputStream(); @@ -107,9 +107,7 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { @Override public void setContentLength(int len) { - if (len > this.content.capacity()) { - this.content.resize(len); - } + this.content.resize(len); } // Overrides Servlet 3.1 setContentLengthLong(long) at runtime @@ -118,16 +116,12 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { throw new IllegalArgumentException("Content-Length exceeds ShallowEtagHeaderFilter's maximum (" + Integer.MAX_VALUE + "): " + len); } - if (len > this.content.capacity()) { - this.content.resize((int) len); - } + this.content.resize((int) len); } @Override public void setBufferSize(int size) { - if (size > this.content.capacity()) { - this.content.resize(size); - } + this.content.resize(size); } @Override @@ -142,7 +136,7 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { } /** - * Return the status code as specifed on the response. + * Return the status code as specified on the response. */ public int getStatusCode() { return this.statusCode; @@ -155,14 +149,24 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { return this.content.toByteArray(); } - private void copyBodyToResponse() throws IOException { + public void copyBodyToResponse() throws IOException { if (this.content.size() > 0) { - getResponse().setContentLength(this.content.size()); - StreamUtils.copy(this.content.toByteArray(), getResponse().getOutputStream()); + HttpServletResponse rawResponse = (HttpServletResponse) getResponse(); + if(! rawResponse.isCommitted()){ + rawResponse.setContentLength(this.content.size()); + } + this.content.writeTo(rawResponse.getOutputStream()); this.content.reset(); } } + public int getContentSize(){ + return this.content.size(); + } + + public InputStream getContentInputStream(){ + return this.content.getInputStream(); + } private class ResponseServletOutputStream extends ServletOutputStream { diff --git a/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java index 74d501147b..da796ac3f8 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.web.filter; +import java.io.ByteArrayInputStream; import java.io.IOException; import javax.servlet.FilterChain; import javax.servlet.ServletException; @@ -51,15 +52,15 @@ public class ShallowEtagHeaderFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); MockHttpServletResponse response = new MockHttpServletResponse(); - assertTrue(filter.isEligibleForEtag(request, response, 200, new byte[0])); - assertFalse(filter.isEligibleForEtag(request, response, 300, new byte[0])); + assertTrue(filter.isEligibleForEtag(request, response, 200, new ByteArrayInputStream(new byte[0]))); + assertFalse(filter.isEligibleForEtag(request, response, 300, new ByteArrayInputStream(new byte[0]))); request = new MockHttpServletRequest("POST", "/hotels"); - assertFalse(filter.isEligibleForEtag(request, response, 200, new byte[0])); + assertFalse(filter.isEligibleForEtag(request, response, 200, new ByteArrayInputStream(new byte[0]))); request = new MockHttpServletRequest("POST", "/hotels"); request.addHeader("Cache-Control","must-revalidate, no-store"); - assertFalse(filter.isEligibleForEtag(request, response, 200, new byte[0])); + assertFalse(filter.isEligibleForEtag(request, response, 200, new ByteArrayInputStream(new byte[0]))); } @Test